diff --git a/rl_coach/architectures/tensorflow_components/heads/__init__.py b/rl_coach/architectures/tensorflow_components/heads/__init__.py index daf5492..9e96553 100644 --- a/rl_coach/architectures/tensorflow_components/heads/__init__.py +++ b/rl_coach/architectures/tensorflow_components/heads/__init__.py @@ -12,6 +12,7 @@ from .quantile_regression_q_head import QuantileRegressionQHead from .rainbow_q_head import RainbowQHead from .v_head import VHead from .acer_policy_head import ACERPolicyHead +from .cil_head import RegressionHead __all__ = [ 'CategoricalQHead', @@ -27,5 +28,6 @@ __all__ = [ 'QuantileRegressionQHead', 'RainbowQHead', 'VHead', - 'ACERPolicyHead' + 'ACERPolicyHead', + 'RegressionHead' ]