mirror of https://github.com/gryf/coach.git (synced 2026-02-15 05:25:55 +01:00)
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -270,8 +270,11 @@ class TensorFlowArchitecture(Architecture):
         elif self.network_is_trainable:
             # not any of the above but is trainable? -> create an operation for applying the gradients to
             # this network weights
-            self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
-                zip(self.weights_placeholders, self.weights), global_step=self.global_step)
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+
+            with tf.control_dependencies(update_ops):
+                self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+                    zip(self.weights_placeholders, self.weights), global_step=self.global_step)
 
     def set_session(self, sess):
         self.sess = sess
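Context for the hunk above (editorial sketch, not part of the commit): in TF 1.x, tf.layers.batch_normalization only registers its moving-average updates under GraphKeys.UPDATE_OPS; nothing runs them unless the training op explicitly depends on them. The following is a minimal, self-contained sketch of that pattern; all names (x, labels, net, loss) are illustrative placeholders, not Coach code.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8], name='x')
labels = tf.placeholder(tf.float32, [None, 1], name='labels')
is_training = tf.placeholder(tf.bool, shape=(), name='is_training')

# batch_normalization only *creates* the moving mean/variance update ops; it does not
# wire them into the training step. They are collected under GraphKeys.UPDATE_OPS.
net = tf.layers.dense(x, 16)
net = tf.layers.batch_normalization(net, training=is_training)
net = tf.nn.relu(net)
out = tf.layers.dense(net, 1)
loss = tf.losses.mean_squared_error(labels, out)

# Without the control dependency, the moving statistics would never be updated and
# inference-time batchnorm would keep its initial (zero mean / unit variance) values.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_x = np.random.randn(32, 8).astype(np.float32)
    batch_y = np.random.randn(32, 1).astype(np.float32)
    sess.run(train_op, feed_dict={x: batch_x, labels: batch_y, is_training: True})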
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):
 
         return feed_dict
 
-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)
+
         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()
 
     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
             self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
             self.sess.run(self.release_init)
 
-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them.
                        The gradients will be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required for when applying the gradients (e.g. batchnorm's
+                                  update ops also requires the inputs)
         """
+
         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)
 
         # release barrier
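Why the extra feed entries matter (editorial sketch, not part of the commit): once the gradient-application op carries the batchnorm update ops, those updates read the network's input placeholders, so feeding only the gradient placeholders would fail with a missing-placeholder error. The standalone sketch below reproduces the situation; every name in it is an illustrative placeholder, not Coach code.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name='x')
is_training = tf.placeholder(tf.bool, shape=(), name='is_training')
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=is_training)
loss = tf.reduce_mean(tf.square(tf.layers.dense(h, 1)))

optimizer = tf.train.GradientDescentOptimizer(0.1)
weights = tf.trainable_variables()
grads = tf.gradients(loss, weights)
grad_placeholders = [tf.placeholder(tf.float32, w.shape) for w in weights]

# The apply op depends on UPDATE_OPS, which in turn read the input placeholders,
# so running it with only the gradient placeholders fed would fail.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    apply_op = optimizer.apply_gradients(zip(grad_placeholders, weights))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(16, 4).astype(np.float32)
    grad_vals = sess.run(grads, feed_dict={x: batch, is_training: True})
    feed = dict(zip(grad_placeholders, grad_vals))
    feed.update({x: batch, is_training: True})  # the "additional inputs"
    sess.run(apply_op, feed_dict=feed)

This mirrors the merged feed_dict in the hunk above: gradient placeholders plus whatever the update ops need.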
@@ -185,6 +185,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
 
         embedder_path = embedder_params.path(emb_type)
         embedder_params_copy = copy.copy(embedder_params)
+        embedder_params_copy.is_training = self.is_training
         embedder_params_copy.activation_function = utils.get_activation_function(embedder_params.activation_function)
         embedder_params_copy.input_rescaling = embedder_params_copy.input_rescaling[emb_type]
         embedder_params_copy.input_offset = embedder_params_copy.input_offset[emb_type]
@@ -204,6 +205,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         middleware_path = middleware_params.path
         middleware_params_copy = copy.copy(middleware_params)
         middleware_params_copy.activation_function = utils.get_activation_function(middleware_params.activation_function)
+        middleware_params_copy.is_training = self.is_training
         module = dynamic_import_and_instantiate_module_from_params(middleware_params_copy, path=middleware_path)
         return module
 
@@ -218,6 +220,7 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
         head_path = head_params.path
         head_params_copy = copy.copy(head_params)
         head_params_copy.activation_function = utils.get_activation_function(head_params_copy.activation_function)
+        head_params_copy.is_training = self.is_training
         return dynamic_import_and_instantiate_module_from_params(head_params_copy, path=head_path, extra_kwargs={
             'agent_parameters': self.ap, 'spaces': self.spaces, 'network_name': self.network_wrapper_name,
             'head_idx': head_idx, 'is_local': self.network_is_local})
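Editorial note on the three hunks above (not part of the commit): each builder attaches is_training to a shallow copy of the shared parameter object, presumably so the flag lives on the per-network copy rather than on the shared definition. A toy illustration with a hypothetical ComponentParameters class:

import copy

class ComponentParameters:
    def __init__(self, activation_function='relu'):
        self.activation_function = activation_function

shared = ComponentParameters()

def prepare(params, is_training):
    params_copy = copy.copy(params)        # shallow copy, mirroring copy.copy(...) in the diff
    params_copy.is_training = is_training  # annotate the copy only
    return params_copy

global_net_params = prepare(shared, is_training=True)
local_net_params = prepare(shared, is_training=False)
assert not hasattr(shared, 'is_training')  # the shared object is never mutated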
@@ -339,7 +342,11 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             head_count += 1
 
         # model weights
-        self.weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)
+        if not self.distributed_training or self.network_is_global:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=self.full_name) if
+                            'global_step' not in var.name]
+        else:
+            self.weights = [var for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.full_name)]
 
         # Losses
         self.losses = tf.losses.get_losses(self.full_name)
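Background for the collection switch above (editorial sketch, not part of the commit): batchnorm's moving_mean and moving_variance are variables but not trainable variables, so a weight list built from TRAINABLE_VARIABLES alone would never include them when weights are copied, synced, or saved. A small standalone sketch; the scope name 'online' is illustrative.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
is_training = tf.placeholder(tf.bool, shape=())
with tf.variable_scope('online'):
    tf.layers.batch_normalization(tf.layers.dense(x, 8), training=is_training)

trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='online')
global_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='online')

# moving_mean / moving_variance appear only in the GLOBAL_VARIABLES list.
print([v.name for v in set(global_vars) - set(trainable)])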
@@ -26,9 +26,9 @@ from rl_coach.spaces import SpacesDefinition
 class DDPGActor(Head):
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int = 0, loss_weight: float = 1., is_local: bool = True, activation_function: str='tanh',
-                 batchnorm: bool=True, dense_layer=Dense):
+                 batchnorm: bool=True, dense_layer=Dense, is_training=False):
         super().__init__(agent_parameters, spaces, network_name, head_idx, loss_weight, is_local, activation_function,
-                         dense_layer=dense_layer)
+                         dense_layer=dense_layer, is_training=is_training)
         self.name = 'ddpg_actor_head'
         self.return_type = ActionProbabilities
 
@@ -50,7 +50,7 @@ class DDPGActor(Head):
                                                          batchnorm=self.batchnorm,
                                                          activation_function=self.activation_function,
                                                          dropout_rate=0,
-                                                         is_training=False,
+                                                         is_training=self.is_training,
                                                          name="BatchnormActivationDropout_0")[-1]
         self.policy_mean = tf.multiply(policy_values_mean, self.output_scale, name='output_mean')
 
@@ -40,7 +40,7 @@ class Head(object):
     """
     def __init__(self, agent_parameters: AgentParameters, spaces: SpacesDefinition, network_name: str,
                  head_idx: int=0, loss_weight: float=1., is_local: bool=True, activation_function: str='relu',
-                 dense_layer=Dense):
+                 dense_layer=Dense, is_training=False):
         self.head_idx = head_idx
         self.network_name = network_name
         self.network_parameters = agent_parameters.network_wrappers[self.network_name]
@@ -64,6 +64,7 @@ class Head(object):
             self.dense_layer = Dense
         else:
             self.dense_layer = convert_layer_class(self.dense_layer)
+        self.is_training = is_training
 
     def __call__(self, input_layer):
         """
@@ -26,6 +26,9 @@ from rl_coach.architectures.tensorflow_components import utils
 def batchnorm_activation_dropout(input_layer, batchnorm, activation_function, dropout_rate, is_training, name):
     layers = [input_layer]
 
+    # Rationale: passing a bool here will mean that batchnorm and or activation will never activate
+    assert not isinstance(is_training, bool)
+
     # batchnorm
     if batchnorm:
         layers.append(
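What the assert expects instead (editorial sketch, not part of the commit): is_training should arrive as a tensor that can be flipped at run time, for example a placeholder with a default, so one graph can serve both training and evaluation. Names below are illustrative, not Coach code.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name='x')
is_training = tf.placeholder_with_default(False, shape=(), name='is_training')

# Feed is_training=True during training so batch statistics are used (and the
# UPDATE_OPS, when run, refresh the moving averages); omit it during evaluation
# so the frozen moving averages are used by default.
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=is_training)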
@@ -17,7 +17,7 @@ from typing import Union, List
 
 import tensorflow as tf
 
-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_FC_Embedding
@@ -18,7 +18,7 @@
 import numpy as np
 import tensorflow as tf
 
-from rl_coach.architectures.tensorflow_components.layers import batchnorm_activation_dropout, Dense
+from rl_coach.architectures.tensorflow_components.layers import Dense
 from rl_coach.architectures.tensorflow_components.middlewares.middleware import Middleware
 from rl_coach.base_parameters import MiddlewareScheme
 from rl_coach.core_types import Middleware_LSTM_Embedding