mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
network_imporvements branch merge
This commit is contained in:
@@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.middlewares.middleware import
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
||||
|
||||
|
||||
class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
@@ -80,6 +80,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
return ret_dict
|
||||
|
||||
self.available_return_types = fill_return_types()
|
||||
self.is_training = None
|
||||
|
||||
def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
|
||||
prediction_type: PredictionType) -> Dict[str, np.ndarray]:
|
||||
@@ -161,7 +162,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy)
|
||||
return module
|
||||
|
||||
def get_output_head(self, head_params: HeadParameters, head_idx: int, loss_weight: float=1.):
|
||||
def get_output_head(self, head_params: HeadParameters, head_idx: int):
|
||||
"""
|
||||
Given a head type, creates the head and returns it
|
||||
:param head_params: the parameters of the head to create
|
||||
@@ -176,7 +177,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function)
|
||||
return dynamic_import_and_instantiate_module_from_params(head_params_copy, extra_kwargs={
|
||||
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
|
||||
'head_idx': head_idx, 'loss_weight': loss_weight, 'is_local': self.network_is_local})
|
||||
'head_idx': head_idx, 'is_local': self.network_is_local})
|
||||
|
||||
def get_model(self):
|
||||
# validate the configuration
|
||||
@@ -189,11 +190,10 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
if self.network_parameters.middleware_parameters is None:
|
||||
raise ValueError("Exactly one middleware type should be defined")
|
||||
|
||||
if len(self.network_parameters.loss_weights) == 0:
|
||||
raise ValueError("At least one loss weight should be defined")
|
||||
|
||||
if len(self.network_parameters.heads_parameters) != len(self.network_parameters.loss_weights):
|
||||
raise ValueError("Number of loss weights should match the number of output types")
|
||||
# ops for defining the training / testing phase
|
||||
self.is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
|
||||
self.is_training_placeholder = tf.placeholder("bool")
|
||||
self.assign_is_training = tf.assign(self.is_training, self.is_training_placeholder)
|
||||
|
||||
for network_idx in range(self.num_networks):
|
||||
with tf.variable_scope('network_{}'.format(network_idx)):
|
||||
@@ -245,28 +245,27 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
|
||||
head_count = 0
|
||||
for head_idx in range(self.num_heads_per_network):
|
||||
for head_copy_idx in range(self.network_parameters.num_output_head_copies):
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
# if we use separate networks per head, then the head type corresponds top the network idx
|
||||
head_type_idx = network_idx
|
||||
head_count = network_idx
|
||||
else:
|
||||
# if we use a single network with multiple embedders, then the head type is the current head idx
|
||||
head_type_idx = head_idx
|
||||
|
||||
if self.network_parameters.use_separate_networks_per_head:
|
||||
# if we use separate networks per head, then the head type corresponds to the network idx
|
||||
head_type_idx = network_idx
|
||||
head_count = network_idx
|
||||
else:
|
||||
# if we use a single network with multiple embedders, then the head type is the current head idx
|
||||
head_type_idx = head_idx
|
||||
head_params = self.network_parameters.heads_parameters[head_type_idx]
|
||||
|
||||
for head_copy_idx in range(head_params.num_output_head_copies):
|
||||
# create output head and add it to the output heads list
|
||||
self.output_heads.append(
|
||||
self.get_output_head(self.network_parameters.heads_parameters[head_type_idx],
|
||||
head_idx*self.network_parameters.num_output_head_copies + head_copy_idx,
|
||||
self.network_parameters.loss_weights[head_type_idx])
|
||||
self.get_output_head(head_params,
|
||||
head_idx*head_params.num_output_head_copies + head_copy_idx)
|
||||
)
|
||||
|
||||
# rescale the gradients from the head
|
||||
self.gradients_from_head_rescalers.append(
|
||||
tf.get_variable('gradients_from_head_{}-{}_rescalers'.format(head_idx, head_copy_idx),
|
||||
initializer=float(
|
||||
self.network_parameters.rescale_gradient_from_head_by_factor[head_count]
|
||||
),
|
||||
initializer=float(head_params.rescale_gradient_from_head_by_factor),
|
||||
dtype=tf.float32))
|
||||
|
||||
self.gradients_from_head_rescalers_placeholders.append(
|
||||
@@ -344,4 +343,46 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
else:
|
||||
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
|
||||
|
||||
def __str__(self):
    """
    Build a human readable, multi-line description of the network structure.

    For every sub-network the description lists the input embedders, the embedding merger (only when there is
    more than one embedder), the middleware and the output heads. When there are several sub-networks, each one
    is indented under a "Sub-network for head: ..." title.

    :return: the textual description of the network
    """
    heads_parameters = self.network_parameters.heads_parameters
    separate_networks = self.network_parameters.use_separate_networks_per_head

    # get_model() appends one entry to self.output_heads for every copy of every head type, so the first copy of
    # head type i is located at the sum of the copy counts of the head types before it, and not at index i.
    # Indexing heads_parameters with an index into output_heads raised IndexError as soon as some head had
    # more than one copy.
    # NOTE(review): for separate networks per head this assumes a single head type per network - confirm
    first_copy_indices = []
    next_free_index = 0
    for params in heads_parameters:
        first_copy_indices.append(next_free_index)
        next_free_index += params.num_output_head_copies

    result = []

    for network in range(self.num_networks):
        network_structure = []

        # embedders
        for embedder in self.input_embedders:
            network_structure.append("Input Embedder: {}".format(embedder.name))
            network_structure.append(indent_string(str(embedder)))

        # the merger is only relevant when there are several embeddings to merge
        if len(self.input_embedders) > 1:
            merged_embeddings = ", ".join("{} embedding".format(e.name) for e in self.input_embedders)
            network_structure.append("{} ({})".format(self.network_parameters.embedding_merger_type.name,
                                                      merged_embeddings))

        # middleware
        network_structure.append("Middleware:")
        network_structure.append(indent_string(str(self.middleware)))

        # heads
        if separate_networks:
            # if we use separate networks per head, then the head type corresponds to the network idx
            head_type_indices = range(network, network + 1)
        else:
            # same head types that get_model() iterates over for a single network
            head_type_indices = range(self.num_heads_per_network)

        for head_type_idx in head_type_indices:
            head_params = heads_parameters[head_type_idx]
            # all the copies of a head share the same parameters, so the first copy represents them all
            head = self.output_heads[first_copy_indices[head_type_idx]]
            if head_params.num_output_head_copies > 1:
                network_structure.append("Output Head: {} (num copies = {})".format(
                    head.name, head_params.num_output_head_copies))
            else:
                network_structure.append("Output Head: {}".format(head.name))
            network_structure.append(indent_string(str(head)))

        # finalize network
        if self.num_networks > 1:
            title_head_idx = first_copy_indices[network] if separate_networks else network
            result.append("Sub-network for head: {}".format(self.output_heads[title_head_idx].name))
            result.append(indent_string('\n'.join(network_structure)))
        else:
            result.append('\n'.join(network_structure))

    return '\n'.join(result)
|
||||
|
||||
Reference in New Issue
Block a user