Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -270,8 +270,11 @@ class TensorFlowArchitecture(Architecture):
         elif self.network_is_trainable:
             # not any of the above but is trainable? -> create an operation for applying the gradients to
             # this network weights
-            self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
-                zip(self.weights_placeholders, self.weights), global_step=self.global_step)
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+            with tf.control_dependencies(update_ops):
+                self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+                    zip(self.weights_placeholders, self.weights), global_step=self.global_step)
 
     def set_session(self, sess):
         self.sess = sess
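
The hunk above adopts the standard TF1 batch-norm training recipe: tf.layers.batch_normalization registers its moving mean/variance updates in the UPDATE_OPS collection, and those ops only run when the train op explicitly depends on them. A minimal, self-contained sketch of that recipe (layer sizes and names are illustrative, not taken from the repo):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
y = tf.placeholder(tf.float32, [None, 1])

h = tf.layers.dense(x, 16, activation=tf.nn.relu)
h = tf.layers.batch_normalization(h, training=True)  # registers ops in tf.GraphKeys.UPDATE_OPS
out = tf.layers.dense(h, 1)
loss = tf.reduce_mean(tf.square(out - y))

# Without this dependency the moving statistics are never updated, and inference
# (training=False) runs on stale batch-norm statistics.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch_x = np.random.randn(32, 8).astype(np.float32)
    batch_y = np.random.randn(32, 1).astype(np.float32)
    sess.run(train_op, feed_dict={x: batch_x, y: batch_y})

In the change itself the collection lookup is additionally scoped to self.full_name, so only this network's batch-norm update ops are attached to its apply-gradients op.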
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):
 
         return feed_dict
 
-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
+
         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()
 
     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
         self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
         self.sess.run(self.release_init)
 
-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them.
                        The gradients will be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+            update ops also require the inputs)
         """
 
         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)
 
             # release barrier
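
The last hunk is the feed-side counterpart: the batch-norm update ops now gated on update_weights_from_batch_gradients read the network's input placeholders, so applying externally computed gradients must also feed those inputs, which is what additional_inputs (merged into feed_dict through create_feed_dict) supplies. A self-contained sketch of the same situation, with illustrative names rather than Coach's real ones:

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name='observation')
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
loss = tf.reduce_mean(tf.square(tf.layers.dense(h, 1)))

weights = tf.trainable_variables()
grads = tf.gradients(loss, weights)
# gradients are applied through placeholders, mirroring update_weights_from_batch_gradients
grad_placeholders = [tf.placeholder(tf.float32, w.get_shape()) for w in weights]

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    apply_op = tf.train.AdamOptimizer(1e-3).apply_gradients(zip(grad_placeholders, weights))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    obs = np.random.randn(32, 4).astype(np.float32)
    grad_values = sess.run(grads, feed_dict={x: obs})

    feed_dict = dict(zip(grad_placeholders, grad_values))
    feed_dict[x] = obs  # without this, the gated batch-norm update ops fail on an unfed placeholder
    sess.run(apply_op, feed_dict=feed_dict)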