From 9e66bb653ee53f46534e8e870ab7fe9639907662 Mon Sep 17 00:00:00 2001 From: Ryan Peach Date: Wed, 5 Dec 2018 04:40:06 -0500 Subject: [PATCH] Enable creating custom tensorflow heads, embedders, and middleware. (#135) Allowing components to have a path property. --- rl_coach/architectures/embedder_parameters.py | 4 ++++ rl_coach/architectures/head_parameters.py | 5 +++++ rl_coach/architectures/middleware_parameters.py | 4 ++++ .../tensorflow_components/general_network.py | 8 +++----- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/rl_coach/architectures/embedder_parameters.py b/rl_coach/architectures/embedder_parameters.py index 679cac9..45269ef 100644 --- a/rl_coach/architectures/embedder_parameters.py +++ b/rl_coach/architectures/embedder_parameters.py @@ -18,6 +18,7 @@ from typing import List, Union from rl_coach.base_parameters import EmbedderScheme, NetworkComponentParameters +MOD_NAMES = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'TensorEmbedder'} class InputEmbedderParameters(NetworkComponentParameters): def __init__(self, activation_function: str='relu', scheme: Union[List, EmbedderScheme]=EmbedderScheme.Medium, @@ -39,3 +40,6 @@ class InputEmbedderParameters(NetworkComponentParameters): self.input_clipping = input_clipping self.name = name self.is_training = is_training + + def path(self, emb_type): + return 'rl_coach.architectures.tensorflow_components.embedders:' + MOD_NAMES[emb_type] diff --git a/rl_coach/architectures/head_parameters.py b/rl_coach/architectures/head_parameters.py index e29d656..0aab2b6 100644 --- a/rl_coach/architectures/head_parameters.py +++ b/rl_coach/architectures/head_parameters.py @@ -31,6 +31,11 @@ class HeadParameters(NetworkComponentParameters): self.loss_weight = loss_weight self.parameterized_class_name = parameterized_class_name + @property + def path(self): + return 'rl_coach.architectures.tensorflow_components.heads:' + self.parameterized_class_name + + class PPOHeadParameters(HeadParameters): def __init__(self, activation_function: str ='tanh', name: str='ppo_head_params', diff --git a/rl_coach/architectures/middleware_parameters.py b/rl_coach/architectures/middleware_parameters.py index 40533cd..73bb4bd 100644 --- a/rl_coach/architectures/middleware_parameters.py +++ b/rl_coach/architectures/middleware_parameters.py @@ -32,6 +32,10 @@ class MiddlewareParameters(NetworkComponentParameters): self.is_training = is_training self.parameterized_class_name = parameterized_class_name + @property + def path(self): + return 'rl_coach.architectures.tensorflow_components.middlewares:' + self.parameterized_class_name + class FCMiddlewareParameters(MiddlewareParameters): def __init__(self, activation_function='relu', diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index ca72bd7..5c70707 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -174,15 +174,13 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}" .format(input_name, allowed_inputs.keys())) - mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'TensorEmbedder'} - emb_type = "vector" if isinstance(allowed_inputs[input_name], TensorObservationSpace): emb_type = "tensor" elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): emb_type = "image" - embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type] + embedder_path = embedder_params.path(emb_type) embedder_params_copy = copy.copy(embedder_params) embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function) embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type] @@ -200,7 +198,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): :return: the middleware instance """ mod_name = middleware_params.parameterized_class_name - middleware_path = 'rl_coach.architectures.tensorflow_components.middlewares:' + mod_name + middleware_path = middleware_params.path middleware_params_copy = copy.copy(middleware_params) middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function) module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path) @@ -214,7 +212,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): :return: the head """ mod_name = head_params.parameterized_class_name - head_path = 'rl_coach.architectures.tensorflow_components.heads:' + mod_name + head_path = head_params.path head_params_copy = copy.copy(head_params) head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function) return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={