1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 03:30:19 +01:00

network_imporvements branch merge

This commit is contained in:
Shadi Endrawis
2018-10-02 13:41:46 +03:00
parent 72ea933384
commit 51726a5b80
110 changed files with 1639 additions and 1161 deletions

View File

@@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.middlewares.middleware import
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.core_types import PredictionType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
class GeneralTensorFlowNetwork(TensorFlowArchitecture):
@@ -80,6 +80,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
return ret_dict
self.available_return_types = fill_return_types()
self.is_training = None
def predict_with_prediction_type(self, states: Dict[str, np.ndarray],
prediction_type: PredictionType) -> Dict[str, np.ndarray]:
@@ -161,7 +162,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy)
return module
def get_output_head(self, head_params: HeadParameters, head_idx: int, loss_weight: float=1.):
def get_output_head(self, head_params: HeadParameters, head_idx: int):
"""
Given a head type, creates the head and returns it
:param head_params: the parameters of the head to create
@@ -176,7 +177,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
head_params_copy.activation_function = self.get_activation_function(head_params_copy.activation_function)
return dynamic_import_and_instantiate_module_from_params(head_params_copy, extra_kwargs={
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
'head_idx': head_idx, 'loss_weight': loss_weight, 'is_local': self.network_is_local})
'head_idx': head_idx, 'is_local': self.network_is_local})
def get_model(self):
# validate the configuration
@@ -189,11 +190,10 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
if self.network_parameters.middleware_parameters is None:
raise ValueError("Exactly one middleware type should be defined")
if len(self.network_parameters.loss_weights) == 0:
raise ValueError("At least one loss weight should be defined")
if len(self.network_parameters.heads_parameters) != len(self.network_parameters.loss_weights):
raise ValueError("Number of loss weights should match the number of output types")
# ops for defining the training / testing phase
self.is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
self.is_training_placeholder = tf.placeholder("bool")
self.assign_is_training = tf.assign(self.is_training, self.is_training_placeholder)
for network_idx in range(self.num_networks):
with tf.variable_scope('network_{}'.format(network_idx)):
@@ -245,28 +245,27 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
head_count = 0
for head_idx in range(self.num_heads_per_network):
for head_copy_idx in range(self.network_parameters.num_output_head_copies):
if self.network_parameters.use_separate_networks_per_head:
# if we use separate networks per head, then the head type corresponds top the network idx
head_type_idx = network_idx
head_count = network_idx
else:
# if we use a single network with multiple embedders, then the head type is the current head idx
head_type_idx = head_idx
if self.network_parameters.use_separate_networks_per_head:
# if we use separate networks per head, then the head type corresponds to the network idx
head_type_idx = network_idx
head_count = network_idx
else:
# if we use a single network with multiple embedders, then the head type is the current head idx
head_type_idx = head_idx
head_params = self.network_parameters.heads_parameters[head_type_idx]
for head_copy_idx in range(head_params.num_output_head_copies):
# create output head and add it to the output heads list
self.output_heads.append(
self.get_output_head(self.network_parameters.heads_parameters[head_type_idx],
head_idx*self.network_parameters.num_output_head_copies + head_copy_idx,
self.network_parameters.loss_weights[head_type_idx])
self.get_output_head(head_params,
head_idx*head_params.num_output_head_copies + head_copy_idx)
)
# rescale the gradients from the head
self.gradients_from_head_rescalers.append(
tf.get_variable('gradients_from_head_{}-{}_rescalers'.format(head_idx, head_copy_idx),
initializer=float(
self.network_parameters.rescale_gradient_from_head_by_factor[head_count]
),
initializer=float(head_params.rescale_gradient_from_head_by_factor),
dtype=tf.float32))
self.gradients_from_head_rescalers_placeholders.append(
@@ -344,4 +343,46 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
else:
raise Exception("{} is not a valid optimizer type".format(self.network_parameters.optimizer_type))
def __str__(self):
"""
Build a human-readable, multi-line description of the network structure.

For every index in ``range(self.num_networks)`` the description lists the input embedders,
the embedding merger type (only when there is more than one input embedder), the middleware
and the output heads. The string form of each component is passed through ``indent_string``
before being added.

NOTE(review): the leading indentation of this method is not visible in this copy of the code,
so the comments below describe individual statements and do not assume how they nest.

:return: all description lines joined with newline characters, as a single string
"""
# accumulates the description of all the sub-networks
result = []
for network in range(self.num_networks):
# description lines collected for the sub-network
network_structure = []
# embedder - a title line with the embedder name, then the embedder's own string form
for embedder in self.input_embedders:
network_structure.append("Input Embedder: {}".format(embedder.name))
network_structure.append(indent_string(str(embedder)))
# with several embedders, also report the merger type and the names of the merged embeddings
if len(self.input_embedders) > 1:
network_structure.append("{} ({})".format(self.network_parameters.embedding_merger_type.name,
", ".join(["{} embedding".format(e.name) for e in self.input_embedders])))
# middleware - a title line, then the middleware's own string form
network_structure.append("Middleware:")
network_structure.append(indent_string(str(self.middleware)))
# head - with separate networks per head only the index equal to the network index is listed,
# otherwise every index of self.output_heads is listed
if self.network_parameters.use_separate_networks_per_head:
heads = range(network, network+1)
else:
heads = range(0, len(self.output_heads))
# NOTE(review): head_idx is used to index both self.output_heads and heads_parameters. The head
# creation loop elsewhere in this class appears to append one output head per head copy, so the
# two lists may differ in length when num_output_head_copies > 1 - confirm the indices line up.
for head_idx in heads:
head = self.output_heads[head_idx]
head_params = self.network_parameters.heads_parameters[head_idx]
# the number of copies is only mentioned in the title when there is more than one copy
if head_params.num_output_head_copies > 1:
network_structure.append("Output Head: {} (num copies = {})".format(head.name, head_params.num_output_head_copies))
else:
network_structure.append("Output Head: {}".format(head.name))
network_structure.append(indent_string(str(head)))
# finalize network - with several sub-networks, each one gets a title line naming
# self.output_heads[network] and its lines are indented; a single network is added as is
if self.num_networks > 1:
result.append("Sub-network for head: {}".format(self.output_heads[network].name))
result.append(indent_string('\n'.join(network_structure)))
else:
result.append('\n'.join(network_structure))
# collapse the collected lines into the final string
result = '\n'.join(result)
return result