mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Removed tensorflow specific code in presets (#59)
* Add generic layer specification for using in presets * Modify presets to use the generic scheme
This commit is contained in:
committed by
Gal Leibovich
parent
811152126c
commit
93571306c3
@@ -24,6 +24,7 @@ from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.head_parameters import HeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
|
||||
from rl_coach.architectures.tensorflow_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
@@ -99,27 +100,6 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
return ret_dict
|
||||
|
||||
@staticmethod
|
||||
def get_activation_function(activation_function_string: str):
|
||||
"""
|
||||
Map the activation function from a string to the tensorflow framework equivalent
|
||||
:param activation_function_string: the type of the activation function
|
||||
:return: the tensorflow activation function
|
||||
"""
|
||||
activation_functions = {
|
||||
'relu': tf.nn.relu,
|
||||
'tanh': tf.nn.tanh,
|
||||
'sigmoid': tf.nn.sigmoid,
|
||||
'elu': tf.nn.elu,
|
||||
'selu': tf.nn.selu,
|
||||
'leaky_relu': tf.nn.leaky_relu,
|
||||
'none': None
|
||||
}
|
||||
assert activation_function_string in activation_functions.keys(), \
|
||||
"Activation function must be one of the following {}. instead it was: {}"\
|
||||
.format(activation_functions.keys(), activation_function_string)
|
||||
return activation_functions[activation_function_string]
|
||||
|
||||
def get_input_embedder(self, input_name: str, embedder_params: InputEmbedderParameters):
|
||||
"""
|
||||
Given an input embedder parameters class, creates the input embedder and returns it
|
||||
@@ -144,7 +124,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type]
|
||||
embedder_params_copy = copy.copy(embedder_params)
|
||||
embedder_params_copy.activation_function = self.get_activation_function(embedder_params.activation_function)
|
||||
embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
|
||||
embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
|
||||
embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
|
||||
embedder_params_copy.name = input_name
|
||||
@@ -162,7 +142,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
mod_name = middleware_params.parameterized_class_name
|
||||
middleware_path = 'rl_coach.architectures.tensorflow_components.middlewares:' + mod_name
|
||||
middleware_params_copy = copy.copy(middleware_params)
|
||||
middleware_params_copy.activation_function = self.get_activation_function(middleware_params.activation_function)
|
||||
middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
|
||||
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
|
||||
return module
|
||||
|
||||
@@ -176,7 +156,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
mod_name = head_params.parameterized_class_name
|
||||
head_path = 'rl_coach.architectures.tensorflow_components.heads:' + mod_name
|
||||
head_params_copy = copy.copy(head_params)
|
||||
head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function)
|
||||
head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
|
||||
return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
|
||||
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
|
||||
'head_idx': head_idx, 'is_local': self.network_is_local})
|
||||
|
||||
Reference in New Issue
Block a user