batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -38,6 +38,7 @@ The environments that were used for testing include:
 |**[Clipped PPO](clipped_ppo)** |  |Mujoco | |
 |**[DDPG](ddpg)** |  |Mujoco | |
 |**[SAC](sac)** |  |Mujoco | |
+|**[TD3](td3)** |  |Mujoco | |
 |**[NEC](nec)** |  |Atari | |
 |**[HER](ddpg_her)** |  |Fetch | |
 |**[DFP](dfp)** |  |Doom | Doom Battle was not verified |
[9 benchmark result images (PNG plots) replaced with updated versions.]
@@ -34,9 +34,9 @@ from rl_coach.spaces import BoxActionSpace, GoalsSpace


 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ class DDPGCriticNetworkParameters(NetworkParameters):


 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ class DDPGAlgorithmParameters(AlgorithmParameters):


 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
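For context, a minimal usage sketch of the new flag (hypothetical preset-style code, not part of this diff; the module path `rl_coach.agents.ddpg_agent` is assumed): batchnorm is now disabled for DDPG by default and has to be requested explicitly.

```python
# Hypothetical preset snippet (assumes DDPGAgentParameters lives in rl_coach.agents.ddpg_agent):
from rl_coach.agents.ddpg_agent import DDPGAgentParameters

agent_params = DDPGAgentParameters()                         # batchnorm disabled (new default)
agent_params_bn = DDPGAgentParameters(use_batchnorm=True)    # opt back in explicitly
```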
@@ -170,7 +170,9 @@ class DDPGAgent(ActorCriticAgent):
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+
+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
@@ -179,11 +181,12 @@ class DDPGAgent(ActorCriticAgent):
                                                  outputs=actor.online_network.weighted_gradients[0],
                                                  initial_feed_dict=initial_feed_dict)

+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads
@@ -22,7 +22,7 @@ from rl_coach.base_parameters import NetworkComponentParameters
 class HeadParameters(NetworkComponentParameters):
     def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head',
                  num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
-                 loss_weight: float=1.0, dense_layer=None):
+                 loss_weight: float=1.0, dense_layer=None, is_training=False):
         super().__init__(dense_layer=dense_layer)
         self.activation_function = activation_function
         self.name = name
@@ -30,6 +30,7 @@ class HeadParameters(NetworkComponentParameters):
         self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
         self.loss_weight = loss_weight
         self.parameterized_class_name = parameterized_class_name
+        self.is_training = is_training

     @property
     def path(self):
@@ -124,31 +124,37 @@ class NetworkWrapper(object):
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)

-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network

         :param gradients: optional gradients that will be used instead of teh accumulated gradients
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
+        :param gradients: optional gradients that will be used instead of teh accumulated gradients
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)

         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.

@@ -157,14 +163,20 @@ class NetworkWrapper(object):
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                    error of this sample. If it is not given, the samples losses won't be scaled
+        :param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients
+                                               (e.g. for incorporating batchnorm update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                           importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)

         return result

-    def apply_gradients_and_sync_networks(self, reset_gradients=True):
+    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
@@ -173,17 +185,22 @@ class NetworkWrapper(object):
                                 the network. this is useful when the accumulated gradients are overwritten instead
                                 if accumulated by the accumulate_gradients function. this allows reducing time
                                 complexity for this function by around 10%
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)

         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()
             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                              additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)

     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """
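A sketch of the two equivalent call paths through the wrapper after this change (illustrative only; `critic`, `critic_inputs` and `TD_targets` are stand-ins taken from the DDPG hunk above):

```python
# One-shot path: forward the inputs to apply-gradients time so batchnorm's update ops see real data.
result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)

# Equivalent explicit form using the lower-level wrapper API:
result = critic.online_network.accumulate_gradients(critic_inputs, TD_targets, no_accumulation=True)
critic.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=critic_inputs)
```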
@@ -270,6 +270,9 @@ class TensorFlowArchitecture(Architecture):
         elif self.network_is_trainable:
             # not any of the above but is trainable? -> create an operation for applying the gradients to
             # this network weights
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+
+            with tf.control_dependencies(update_ops):
                 self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
                     zip(self.weights_placeholders, self.weights), global_step=self.global_step)
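For readers who don't have the TF1 batchnorm contract in mind, here is a standalone sketch of what this hunk works around (plain TensorFlow 1.x, not repo code): the moving-average updates are created only in `tf.GraphKeys.UPDATE_OPS`, and running them requires the same inputs to be fed.

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

x = tf.placeholder(tf.float32, [None, 4])
is_training = tf.placeholder_with_default(False, shape=[])
h = tf.layers.batch_normalization(x, training=is_training)

# The moving mean/variance updates are not part of the train op by default;
# they live in UPDATE_OPS and need real inputs in the feed_dict in order to run.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.no_op(name='apply_gradients_stand_in')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={x: np.random.rand(8, 4), is_training: True})
```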
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):

         return feed_dict

-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)

         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()

     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
         self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
         self.sess.run(self.release_init)

-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them.
                        The gradients will be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)
         """

         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)

             # release barrier
@@ -185,6 +185,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):

         embedder_path = embedder_params.path(emb_type)
         embedder_params_copy = copy.copy(embedder_params)
+        embedder_params_copy.is_training = self.is_training
         embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
         embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
         embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         middleware_path = middleware_params.path
         middleware_params_copy = copy.copy(middleware_params)
         middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
+        middleware_params_copy.is_training = self.is_training
         module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
         return module

@@ -218,6 +220,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         head_path = head_params.path
         head_params_copy = copy.copy(head_params)
         head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
+        head_params_copy.is_training = self.is_training
         return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
             'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
             'head_idx': head_idx, 'is_local': self.network_is_local})
@@ -339,7 +342,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
                 head_count += 1

         # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+        if not self.distributed_training or self.network_is_global:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
+                            'global_step' not in var.name]
+        else:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]

         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
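The reason for switching collections, as a standalone sketch (plain TF 1.x behaviour, not repo code): batchnorm's moving mean/variance are created as non-trainable variables, so they appear in `GLOBAL_VARIABLES` but not in `TRAINABLE_VARIABLES`, and collecting weights only from the trainable set would leave them out when copying or syncing weights.

```python
import tensorflow as tf  # TensorFlow 1.x

x = tf.placeholder(tf.float32, [None, 4])
with tf.variable_scope('net'):
    tf.layers.batch_normalization(x, training=True)

trainable = {v.name for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='net')}
all_vars = {v.name for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net')}

# batchnorm's moving_mean / moving_variance appear only in the full set:
print(sorted(all_vars - trainable))
```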
@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
 class DDPGActor(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
-                 batchnorm: bool=True, dense_layer=Dense):
+                 batchnorm: bool=True, dense_layer=Dense, is_training=False):
         super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
-                         dense_layer=dense_layer)
+                         dense_layer=dense_layer, is_training=is_training)
         self.name = 'ddpg_actor_head'
         self.return_type = ActionProbabilities

@@ -50,7 +50,7 @@ class DDPGActor(Head):
                                                          batchnorm=self.batchnorm,
                                                          activation_function=self.activation_function,
                                                          dropout_rate=0,
-                                                         is_training=False,
+                                                         is_training=self.is_training,
                                                          name="BatchnormActivationDropout_0")[-1]
             self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
@@ -40,7 +40,7 @@ class Head(object):
     """
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
-                 dense_layer=Dense):
+                 dense_layer=Dense, is_training=False):
         self.head_idx = head_idx
         self.network_name = network_name
         self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ class Head(object):
             self.dense_layer = Dense
         else:
             self.dense_layer = convert_layer_class(self.dense_layer)
+        self.is_training = is_training

     def __call__(self, input_layer):
         """
@@ -26,6 +26,9 @@ from rl_coach.architectures.tensorflow_components import utils
 def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
     layers = [input_layer]

+    # Rationale: passing a bool here will mean that batchnorm and or activation will never activate
+    assert not isinstance(is_training, bool)
+
     # batchnorm
     if batchnorm:
         layers.append(
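Why the assert matters, sketched outside the repo (plain TF 1.x, illustrative only): a Python bool is frozen into the graph at construction time, so the layer could never switch between training and inference, whereas a tensor-valued flag lets the same graph run in either mode.

```python
import tensorflow as tf  # TensorFlow 1.x

x = tf.placeholder(tf.float32, [None, 4])

# A Python bool is baked into the graph when it is built -- this is what the assert rejects:
# h = tf.layers.batch_normalization(x, training=False)

# A tensor-valued flag can be flipped at runtime instead:
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
set_train_mode = tf.assign(is_training, True)
set_eval_mode = tf.assign(is_training, False)
h = tf.layers.batch_normalization(x, training=is_training)
```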
@@ -17,7 +17,7 @@ from typing import Union, List

 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_FC_Embedding
@@ -18,7 +18,7 @@
 import numpy as np
 import tensorflow as tf

-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_LSTM_Embedding
@@ -25,17 +25,20 @@ def test_embedder(reset):
     with pytest.raises(ValueError):
         embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")

-    # creating a simple image embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
-
-    # make sure the ops where not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
+    # creating a simple image embedder
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)
+
+    # make sure the only the is_training op is creates
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 100, 100, 10)
@@ -55,7 +58,9 @@ def test_embedder(reset):
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep vector embedder
-    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
+                             is_training=is_training)

     # call the embedder
     embedder()
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep image embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
-                             activation_function=tf.nn.relu)
+                             activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -86,7 +92,7 @@ def test_activation_function(reset):

     # creating a deep image embedder with tanh
     embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                  activation_function=tf.nn.tanh)
+                                  activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()
@@ -22,16 +22,19 @@ def test_embedder(reset):
         embedder = VectorEmbedder(np.array([10, 10]), name="test")

     # creating a simple vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test")
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    pre_ops = len(tf.get_default_graph().get_operations())
+
+    embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)

     # make sure the ops where not created yet
-    assert len(tf.get_default_graph().get_operations()) == 0
+    assert len(tf.get_default_graph().get_operations()) == pre_ops

     # call the embedder
     input_ph, output_ph = embedder()

     # make sure that now the ops were created
-    assert len(tf.get_default_graph().get_operations()) > 0
+    assert len(tf.get_default_graph().get_operations()) > pre_ops

     # try feeding a batch of one example
     input = np.random.rand(1, 10)
@@ -51,7 +54,8 @@ def test_embedder(reset):
 @pytest.mark.unit_test
 def test_complex_embedder(reset):
     # creating a deep vector embedder
-    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep)
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
+    embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)

     # call the embedder
     embedder()
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
 @pytest.mark.unit_test
 def test_activation_function(reset):
     # creating a deep vector embedder with relu
+    is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
     embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
-                              activation_function=tf.nn.relu)
+                              activation_function=tf.nn.relu, is_training=is_training)

     # call the embedder
     embedder()
@@ -82,7 +87,7 @@ def test_activation_function(reset):

     # creating a deep vector embedder with tanh
     embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
-                                   activation_function=tf.nn.tanh)
+                                   activation_function=tf.nn.tanh, is_training=is_training)

     # call the embedder
     embedder_tanh()
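Since the tests now create `is_training` as a local, non-trainable variable, any session that actually reads it must also run the local-variables initializer. A minimal sketch of that pattern (illustrative only, not part of the test files):

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
x = tf.placeholder(tf.float32, [None, 10])
h = tf.layers.batch_normalization(x, training=is_training)

with tf.Session() as sess:
    # the flag lives in LOCAL_VARIABLES, so it needs the local initializer as well
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    sess.run(h, feed_dict={x: np.random.rand(1, 10)})
```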