1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>
This commit is contained in:
Gal Leibovich
2019-06-23 11:28:22 +03:00
committed by GitHub
parent 7b5d6a3f03
commit d6795bd524
22 changed files with 105 additions and 50 deletions

View File

@@ -38,6 +38,7 @@ The environments that were used for testing include:
|**[Clipped PPO](clipped_ppo)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | | |**[Clipped PPO](clipped_ppo)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[DDPG](ddpg)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | | |**[DDPG](ddpg)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[SAC](sac)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | | |**[SAC](sac)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[TD3](td3)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Mujoco | |
|**[NEC](nec)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | | |**[NEC](nec)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Atari | |
|**[HER](ddpg_her)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Fetch | | |**[HER](ddpg_her)** | ![#2E8B57](https://placehold.it/15/2E8B57/000000?text=+) |Fetch | |
|**[DFP](dfp)** | ![#ceffad](https://placehold.it/15/ceffad/000000?text=+) |Doom | Doom Battle was not verified | |**[DFP](dfp)** | ![#ceffad](https://placehold.it/15/ceffad/000000?text=+) |Doom | Doom Battle was not verified |

Binary file not shown.

Before

Width:  |  Height:  |  Size: 135 KiB

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 89 KiB

After

Width:  |  Height:  |  Size: 109 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 111 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 113 KiB

After

Width:  |  Height:  |  Size: 82 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 104 KiB

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 127 KiB

After

Width:  |  Height:  |  Size: 135 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 70 KiB

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 117 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 119 KiB

After

Width:  |  Height:  |  Size: 115 KiB

View File

@@ -34,9 +34,9 @@ from rl_coach.spaces import BoxActionSpace, GoalsSpace
class DDPGCriticNetworkParameters(NetworkParameters): class DDPGCriticNetworkParameters(NetworkParameters):
def __init__(self): def __init__(self, use_batchnorm=False):
super().__init__() super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True), self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)} 'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
self.middleware_parameters = FCMiddlewareParameters() self.middleware_parameters = FCMiddlewareParameters()
self.heads_parameters = [DDPGVHeadParameters()] self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ class DDPGCriticNetworkParameters(NetworkParameters):
class DDPGActorNetworkParameters(NetworkParameters): class DDPGActorNetworkParameters(NetworkParameters):
def __init__(self): def __init__(self, use_batchnorm=False):
super().__init__() super().__init__()
self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)} self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
self.middleware_parameters = FCMiddlewareParameters(batchnorm=True) self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
self.heads_parameters = [DDPGActorHeadParameters()] self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
self.optimizer_type = 'Adam' self.optimizer_type = 'Adam'
self.batch_size = 64 self.batch_size = 64
self.adam_optimizer_beta2 = 0.999 self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ class DDPGAlgorithmParameters(AlgorithmParameters):
class DDPGAgentParameters(AgentParameters): class DDPGAgentParameters(AgentParameters):
def __init__(self): def __init__(self, use_batchnorm=False):
super().__init__(algorithm=DDPGAlgorithmParameters(), super().__init__(algorithm=DDPGAlgorithmParameters(),
exploration=OUProcessParameters(), exploration=OUProcessParameters(),
memory=EpisodicExperienceReplayParameters(), memory=EpisodicExperienceReplayParameters(),
networks=OrderedDict([("actor", DDPGActorNetworkParameters()), networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
("critic", DDPGCriticNetworkParameters())])) ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))
@property @property
def path(self): def path(self):
@@ -170,7 +170,9 @@ class DDPGAgent(ActorCriticAgent):
# train the critic # train the critic
critic_inputs = copy.copy(batch.states(critic_keys)) critic_inputs = copy.copy(batch.states(critic_keys))
critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1) critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
result = critic.train_and_sync_networks(critic_inputs, TD_targets)
# also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
total_loss, losses, unclipped_grads = result[:3] total_loss, losses, unclipped_grads = result[:3]
# apply the gradients from the critic to the actor # apply the gradients from the critic to the actor
@@ -179,11 +181,12 @@ class DDPGAgent(ActorCriticAgent):
outputs=actor.online_network.weighted_gradients[0], outputs=actor.online_network.weighted_gradients[0],
initial_feed_dict=initial_feed_dict) initial_feed_dict=initial_feed_dict)
# also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
if actor.has_global: if actor.has_global:
actor.apply_gradients_to_global_network(gradients) actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
actor.update_online_network() actor.update_online_network()
else: else:
actor.apply_gradients_to_online_network(gradients) actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
return total_loss, losses, unclipped_grads return total_loss, losses, unclipped_grads

View File

@@ -22,7 +22,7 @@ from rl_coach.base_parameters import NetworkComponentParameters
class HeadParameters(NetworkComponentParameters): class HeadParameters(NetworkComponentParameters):
def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head', def __init__(self, parameterized_class_name: str, activation_function: str = 'relu', name: str= 'head',
num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0, num_output_head_copies: int=1, rescale_gradient_from_head_by_factor: float=1.0,
loss_weight: float=1.0, dense_layer=None): loss_weight: float=1.0, dense_layer=None, is_training=False):
super().__init__(dense_layer=dense_layer) super().__init__(dense_layer=dense_layer)
self.activation_function = activation_function self.activation_function = activation_function
self.name = name self.name = name
@@ -30,6 +30,7 @@ class HeadParameters(NetworkComponentParameters):
self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor self.rescale_gradient_from_head_by_factor = rescale_gradient_from_head_by_factor
self.loss_weight = loss_weight self.loss_weight = loss_weight
self.parameterized_class_name = parameterized_class_name self.parameterized_class_name = parameterized_class_name
self.is_training = is_training
@property @property
def path(self): def path(self):

View File

@@ -124,31 +124,37 @@ class NetworkWrapper(object):
if self.global_network: if self.global_network:
self.online_network.set_weights(self.global_network.get_weights(), rate) self.online_network.set_weights(self.global_network.get_weights(), rate)
def apply_gradients_to_global_network(self, gradients=None): def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
""" """
Apply gradients from the online network on the global network Apply gradients from the online network on the global network
:param gradients: optional gradients that will be used instead of teh accumulated gradients :param gradients: optional gradients that will be used instead of teh accumulated gradients
:param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
update ops also requires the inputs)
:return: :return:
""" """
if gradients is None: if gradients is None:
gradients = self.online_network.accumulated_gradients gradients = self.online_network.accumulated_gradients
if self.network_parameters.shared_optimizer: if self.network_parameters.shared_optimizer:
self.global_network.apply_gradients(gradients) self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
else: else:
self.online_network.apply_gradients(gradients) self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
def apply_gradients_to_online_network(self, gradients=None): def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
""" """
Apply gradients from the online network on itself Apply gradients from the online network on itself
:param gradients: optional gradients that will be used instead of teh accumulated gradients
:param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
update ops also requires the inputs)
:return: :return:
""" """
if gradients is None: if gradients is None:
gradients = self.online_network.accumulated_gradients gradients = self.online_network.accumulated_gradients
self.online_network.apply_gradients(gradients) self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None): def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
use_inputs_for_apply_gradients=False):
""" """
A generic training function that enables multi-threading training using a global network if necessary. A generic training function that enables multi-threading training using a global network if necessary.
@@ -157,14 +163,20 @@ class NetworkWrapper(object):
:param additional_fetches: Any additional tensor the user wants to fetch :param additional_fetches: Any additional tensor the user wants to fetch
:param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
error of this sample. If it is not given, the samples losses won't be scaled error of this sample. If it is not given, the samples losses won't be scaled
:param use_inputs_for_apply_gradients: Add the inputs also for when applying gradients
(e.g. for incorporating batchnorm update ops)
:return: The loss of the training iteration :return: The loss of the training iteration
""" """
result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches, result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
importance_weights=importance_weights, no_accumulation=True) importance_weights=importance_weights, no_accumulation=True)
if use_inputs_for_apply_gradients:
self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
else:
self.apply_gradients_and_sync_networks(reset_gradients=False) self.apply_gradients_and_sync_networks(reset_gradients=False)
return result return result
def apply_gradients_and_sync_networks(self, reset_gradients=True): def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
""" """
Applies the gradients accumulated in the online network to the global network or to itself and syncs the Applies the gradients accumulated in the online network to the global network or to itself and syncs the
networks if necessary networks if necessary
@@ -173,17 +185,22 @@ class NetworkWrapper(object):
the network. this is useful when the accumulated gradients are overwritten instead the network. this is useful when the accumulated gradients are overwritten instead
if accumulated by the accumulate_gradients function. this allows reducing time if accumulated by the accumulate_gradients function. this allows reducing time
complexity for this function by around 10% complexity for this function by around 10%
:param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
update ops also requires the inputs)
""" """
if self.global_network: if self.global_network:
self.apply_gradients_to_global_network() self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
if reset_gradients: if reset_gradients:
self.online_network.reset_accumulated_gradients() self.online_network.reset_accumulated_gradients()
self.update_online_network() self.update_online_network()
else: else:
if reset_gradients: if reset_gradients:
self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients) self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
additional_inputs=additional_inputs)
else: else:
self.online_network.apply_gradients(self.online_network.accumulated_gradients) self.online_network.apply_gradients(self.online_network.accumulated_gradients,
additional_inputs=additional_inputs)
def parallel_prediction(self, network_input_tuples: List[Tuple]): def parallel_prediction(self, network_input_tuples: List[Tuple]):
""" """

View File

@@ -270,6 +270,9 @@ class TensorFlowArchitecture(Architecture):
elif self.network_is_trainable: elif self.network_is_trainable:
# not any of the above but is trainable? -> create an operation for applying the gradients to # not any of the above but is trainable? -> create an operation for applying the gradients to
# this network weights # this network weights
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
with tf.control_dependencies(update_ops):
self.update_weights_from_batch_gradients = self.optimizer.apply_gradients( self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
zip(self.weights_placeholders, self.weights), global_step=self.global_step) zip(self.weights_placeholders, self.weights), global_step=self.global_step)
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):
return feed_dict return feed_dict
def apply_and_reset_gradients(self, gradients, scaler=1.): def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
""" """
Applies the given gradients to the network weights and resets the accumulation placeholder Applies the given gradients to the network weights and resets the accumulation placeholder
:param gradients: The gradients to use for the update :param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them :param scaler: A scaling factor that allows rescaling the gradients before applying them
:param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
update ops also requires the inputs)
""" """
self.apply_gradients(gradients, scaler) self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
self.reset_accumulated_gradients() self.reset_accumulated_gradients()
def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False): def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers) self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
self.sess.run(self.release_init) self.sess.run(self.release_init)
def apply_gradients(self, gradients, scaler=1.): def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
""" """
Applies the given gradients to the network weights Applies the given gradients to the network weights
:param gradients: The gradients to use for the update :param gradients: The gradients to use for the update
:param scaler: A scaling factor that allows rescaling the gradients before applying them. :param scaler: A scaling factor that allows rescaling the gradients before applying them.
The gradients will be MULTIPLIED by this factor The gradients will be MULTIPLIED by this factor
:param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
update ops also requires the inputs)
""" """
if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters): if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
if hasattr(self, 'global_step') and not self.network_is_local: if hasattr(self, 'global_step') and not self.network_is_local:
self.sess.run(self.inc_step) self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
# async distributed training / distributed training with independent optimizer # async distributed training / distributed training with independent optimizer
# / non-distributed training - just apply the gradients # / non-distributed training - just apply the gradients
feed_dict = dict(zip(self.weights_placeholders, gradients)) feed_dict = dict(zip(self.weights_placeholders, gradients))
if additional_inputs is not None:
feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict) self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)
# release barrier # release barrier

View File

@@ -185,6 +185,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
embedder_path = embedder_params.path(emb_type) embedder_path = embedder_params.path(emb_type)
embedder_params_copy = copy.copy(embedder_params) embedder_params_copy = copy.copy(embedder_params)
embedder_params_copy.is_training = self.is_training
embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function) embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type] embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type] embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
middleware_path = middleware_params.path middleware_path = middleware_params.path
middleware_params_copy = copy.copy(middleware_params) middleware_params_copy = copy.copy(middleware_params)
middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function) middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
middleware_params_copy.is_training = self.is_training
module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path) module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
return module return module
@@ -218,6 +220,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
head_path = head_params.path head_path = head_params.path
head_params_copy = copy.copy(head_params) head_params_copy = copy.copy(head_params)
head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function) head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
head_params_copy.is_training = self.is_training
return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={ return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name, 'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
'head_idx': head_idx, 'is_local': self.network_is_local}) 'head_idx': head_idx, 'is_local': self.network_is_local})
@@ -339,7 +342,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
head_count += 1 head_count += 1
# model weights # model weights
self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name) if not self.distributed_training or self.network_is_global:
self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
'global_step' not in var.name]
else:
self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]
# Losses # Losses
self.losses = tf.losses.get_losses(self.full_name) self.losses = tf.losses.get_losses(self.full_name)

View File

@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
class DDPGActor(Head): class DDPGActor(Head):
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh', head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
batchnorm: bool=True, dense_layer=Dense): batchnorm: bool=True, dense_layer=Dense, is_training=False):
super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function, super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
dense_layer=dense_layer) dense_layer=dense_layer, is_training=is_training)
self.name = 'ddpg_actor_head' self.name = 'ddpg_actor_head'
self.return_type = ActionProbabilities self.return_type = ActionProbabilities
@@ -50,7 +50,7 @@ class DDPGActor(Head):
batchnorm=self.batchnorm, batchnorm=self.batchnorm,
activation_function=self.activation_function, activation_function=self.activation_function,
dropout_rate=0, dropout_rate=0,
is_training=False, is_training=self.is_training,
name="BatchnormActivationDropout_0")[-1] name="BatchnormActivationDropout_0")[-1]
self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean') self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')

View File

@@ -40,7 +40,7 @@ class Head(object):
""" """
def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str, def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu', head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
dense_layer=Dense): dense_layer=Dense, is_training=False):
self.head_idx = head_idx self.head_idx = head_idx
self.network_name = network_name self.network_name = network_name
self.network_parameters = agent_parameters.network_wrappers[self.network_name] self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ class Head(object):
self.dense_layer = Dense self.dense_layer = Dense
else: else:
self.dense_layer = convert_layer_class(self.dense_layer) self.dense_layer = convert_layer_class(self.dense_layer)
self.is_training = is_training
def __call__(self, input_layer): def __call__(self, input_layer):
""" """

View File

@@ -26,6 +26,9 @@ from rl_coach.architectures.tensorflow_components import utils
def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name): def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
layers = [input_layer] layers = [input_layer]
# Rationale: passing a bool here will mean that batchnorm and or activation will never activate
assert not isinstance(is_training, bool)
# batchnorm # batchnorm
if batchnorm: if batchnorm:
layers.append( layers.append(

View File

@@ -17,7 +17,7 @@ from typing import Union, List
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
from rl_coach.base_parameters import MiddlewareScheme from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.core_types import Middleware_FC_Embedding from rl_coach.core_types import Middleware_FC_Embedding

View File

@@ -18,7 +18,7 @@
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense from rl_coach.architectures.tensorflow_components.layers import Dense
from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
from rl_coach.base_parameters import MiddlewareScheme from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.core_types import Middleware_LSTM_Embedding from rl_coach.core_types import Middleware_LSTM_Embedding

View File

@@ -25,17 +25,20 @@ def test_embedder(reset):
with pytest.raises(ValueError): with pytest.raises(ValueError):
embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test") embedder = ImageEmbedder(np.array([10, 100, 100, 100]), name="test")
# creating a simple image embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test")
# make sure the ops where not created yet is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
assert len(tf.get_default_graph().get_operations()) == 0 pre_ops = len(tf.get_default_graph().get_operations())
# creating a simple image embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", is_training=is_training)
# make sure the only the is_training op is creates
assert len(tf.get_default_graph().get_operations()) == pre_ops
# call the embedder # call the embedder
input_ph, output_ph = embedder() input_ph, output_ph = embedder()
# make sure that now the ops were created # make sure that now the ops were created
assert len(tf.get_default_graph().get_operations()) > 0 assert len(tf.get_default_graph().get_operations()) > pre_ops
# try feeding a batch of one example # try feeding a batch of one example
input = np.random.rand(1, 100, 100, 10) input = np.random.rand(1, 100, 100, 10)
@@ -55,7 +58,9 @@ def test_embedder(reset):
@pytest.mark.unit_test @pytest.mark.unit_test
def test_complex_embedder(reset): def test_complex_embedder(reset):
# creating a deep vector embedder # creating a deep vector embedder
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep) is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = ImageEmbedder(np.array([100, 100, 10]), name="test", scheme=EmbedderScheme.Deep,
is_training=is_training)
# call the embedder # call the embedder
embedder() embedder()
@@ -71,8 +76,9 @@ def test_complex_embedder(reset):
@pytest.mark.unit_test @pytest.mark.unit_test
def test_activation_function(reset): def test_activation_function(reset):
# creating a deep image embedder with relu # creating a deep image embedder with relu
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep, embedder = ImageEmbedder(np.array([100, 100, 10]), name="relu", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.relu) activation_function=tf.nn.relu, is_training=is_training)
# call the embedder # call the embedder
embedder() embedder()
@@ -86,7 +92,7 @@ def test_activation_function(reset):
# creating a deep image embedder with tanh # creating a deep image embedder with tanh
embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep, embedder_tanh = ImageEmbedder(np.array([100, 100, 10]), name="tanh", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.tanh) activation_function=tf.nn.tanh, is_training=is_training)
# call the embedder # call the embedder
embedder_tanh() embedder_tanh()

View File

@@ -22,16 +22,19 @@ def test_embedder(reset):
embedder = VectorEmbedder(np.array([10, 10]), name="test") embedder = VectorEmbedder(np.array([10, 10]), name="test")
# creating a simple vector embedder # creating a simple vector embedder
embedder = VectorEmbedder(np.array([10]), name="test") is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
pre_ops = len(tf.get_default_graph().get_operations())
embedder = VectorEmbedder(np.array([10]), name="test", is_training=is_training)
# make sure the ops where not created yet # make sure the ops where not created yet
assert len(tf.get_default_graph().get_operations()) == 0 assert len(tf.get_default_graph().get_operations()) == pre_ops
# call the embedder # call the embedder
input_ph, output_ph = embedder() input_ph, output_ph = embedder()
# make sure that now the ops were created # make sure that now the ops were created
assert len(tf.get_default_graph().get_operations()) > 0 assert len(tf.get_default_graph().get_operations()) > pre_ops
# try feeding a batch of one example # try feeding a batch of one example
input = np.random.rand(1, 10) input = np.random.rand(1, 10)
@@ -51,7 +54,8 @@ def test_embedder(reset):
@pytest.mark.unit_test @pytest.mark.unit_test
def test_complex_embedder(reset): def test_complex_embedder(reset):
# creating a deep vector embedder # creating a deep vector embedder
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep) is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = VectorEmbedder(np.array([10]), name="test", scheme=EmbedderScheme.Deep, is_training=is_training)
# call the embedder # call the embedder
embedder() embedder()
@@ -67,8 +71,9 @@ def test_complex_embedder(reset):
@pytest.mark.unit_test @pytest.mark.unit_test
def test_activation_function(reset): def test_activation_function(reset):
# creating a deep vector embedder with relu # creating a deep vector embedder with relu
is_training = tf.Variable(False, trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES])
embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep, embedder = VectorEmbedder(np.array([10]), name="relu", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.relu) activation_function=tf.nn.relu, is_training=is_training)
# call the embedder # call the embedder
embedder() embedder()
@@ -82,7 +87,7 @@ def test_activation_function(reset):
# creating a deep vector embedder with tanh # creating a deep vector embedder with tanh
embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep, embedder_tanh = VectorEmbedder(np.array([10]), name="tanh", scheme=EmbedderScheme.Deep,
activation_function=tf.nn.tanh) activation_function=tf.nn.tanh, is_training=is_training)
# call the embedder # call the embedder
embedder_tanh() embedder_tanh()