diff --git a/rl_coach/architectures/tensorflow_components/heads/q_head.py b/rl_coach/architectures/tensorflow_components/heads/q_head.py
index 32c6946..eedec5b 100644
--- a/rl_coach/architectures/tensorflow_components/heads/q_head.py
+++ b/rl_coach/architectures/tensorflow_components/heads/q_head.py
@@ -34,6 +34,12 @@ class QHead(Head):
             self.num_actions = 1
         elif isinstance(self.spaces.action, DiscreteActionSpace):
             self.num_actions = len(self.spaces.action.actions)
+        else:
+            raise ValueError(
+                'QHead does not support action spaces of type: {class_name}'.format(
+                    class_name=self.spaces.action.__class__.__name__,
+                )
+            )
         self.return_type = QActionStateValue
         if agent_parameters.network_wrappers[self.network_name].replace_mse_with_huber_loss:
             self.loss_type = tf.losses.huber_loss
@@ -49,5 +55,3 @@ class QHead(Head):
             "Dense (num outputs = {})".format(self.num_actions)
         ]
         return '\n'.join(result)
-
-
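
For context, a minimal standalone sketch of the control flow this change completes: the isinstance chain in QHead.__init__ previously fell through silently for unsupported action spaces, leaving num_actions unset and failing later with an opaque AttributeError; the new else branch fails fast instead. The FakeBoxSpace / FakeDiscreteSpace classes and resolve_num_actions below are hypothetical stand-ins for illustration only, not rl_coach types.

    # Hypothetical stand-ins for rl_coach's BoxActionSpace / DiscreteActionSpace,
    # used only to illustrate the dispatch pattern in the diff above.
    class FakeBoxSpace:
        pass

    class FakeDiscreteSpace:
        def __init__(self, actions):
            self.actions = actions

    def resolve_num_actions(action_space):
        # Mirrors the isinstance chain in QHead.__init__: a continuous (box)
        # space yields a single Q output, a discrete space yields one output
        # per action.
        if isinstance(action_space, FakeBoxSpace):
            return 1
        elif isinstance(action_space, FakeDiscreteSpace):
            return len(action_space.actions)
        else:
            # The branch added by this diff: raise a descriptive error rather
            # than leaving the output count undefined.
            raise ValueError(
                'QHead does not support action spaces of type: {class_name}'.format(
                    class_name=action_space.__class__.__name__,
                )
            )

    print(resolve_num_actions(FakeDiscreteSpace(actions=[0, 1, 2])))  # -> 3
    print(resolve_num_actions(FakeBoxSpace()))                        # -> 1
    resolve_num_actions(object())  # raises ValueError naming the bad type

The design choice is fail-fast validation: an unsupported action space is rejected at head construction time, with the offending class name in the message, instead of surfacing later as a missing-attribute crash deep inside network building.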