mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Move embedder, middleware, and head parameters to framework agnostic modules. (#45)
Part of #28
This commit is contained in:
committed by
Scott Leishman
parent
16b3e99f37
commit
a888226641
@@ -20,10 +20,10 @@ from typing import Dict
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedderParameters
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.head_parameters import HeadParameters
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.tensorflow_components.architecture import TensorFlowArchitecture
|
||||
from rl_coach.architectures.tensorflow_components.heads.head import HeadParameters
|
||||
from rl_coach.architectures.tensorflow_components.middlewares.middleware import MiddlewareParameters
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
@@ -136,15 +136,17 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
|
||||
.format(input_name, allowed_inputs.keys()))
|
||||
|
||||
type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
type = "image"
|
||||
mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder'}
|
||||
|
||||
embedder_path = 'rl_coach.architectures.tensorflow_components.embedders.' + embedder_params.path[type]
|
||||
emb_type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
emb_type = "image"
|
||||
|
||||
embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type]
|
||||
embedder_params_copy = copy.copy(embedder_params)
|
||||
embedder_params_copy.activation_function = self.get_activation_function(embedder_params.activation_function)
|
||||
embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[type]
|
||||
embedder_params_copy.input_offset = embedder_params_copy.input_offset[type]
|
||||
embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
|
||||
embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
|
||||
embedder_params_copy.name = input_name
|
||||
module = dynamic_import_and_instantiate_module_from_params(embedder_params_copy,
|
||||
path=embedder_path,
|
||||
@@ -157,25 +159,25 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
:param middleware_params: the paramaeters of the middleware class
|
||||
:return: the middleware instance
|
||||
"""
|
||||
mod_name = middleware_params.parameterized_class_name
|
||||
middleware_path = 'rl_coach.architectures.tensorflow_components.middlewares:' + mod_name
|
||||
middleware_params_copy = copy.copy(middleware_params)
|
||||
middleware_params_copy.activation_function = self.get_activation_function(middleware_params.activation_function)
|
||||
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy)
|
||||
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
|
||||
return module
|
||||
|
||||
def get_output_head(self, head_params: HeadParameters, head_idx: int):
|
||||
"""
|
||||
Given a head type, creates the head and returns it
|
||||
:param head_params: the parameters of the head to create
|
||||
:param head_type: the path to the class of the head under the embedders directory or a full path to a head class.
|
||||
the path should be in the following structure: <module_path>:<class_path>
|
||||
:param head_idx: the head index
|
||||
:param loss_weight: the weight to assign for the embedders loss
|
||||
:return: the head
|
||||
"""
|
||||
|
||||
mod_name = head_params.parameterized_class_name
|
||||
head_path = 'rl_coach.architectures.tensorflow_components.heads:' + mod_name
|
||||
head_params_copy = copy.copy(head_params)
|
||||
head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function)
|
||||
return dynamic_import_and_instantiate_module_from_params(head_params_copy, extra_kwargs={
|
||||
return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
|
||||
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
|
||||
'head_idx': head_idx, 'is_local': self.network_is_local})
|
||||
|
||||
|
||||
Reference in New Issue
Block a user