
batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>

Author:    Gal Leibovich
Date:      2019-06-23 11:28:22 +03:00
Committer: GitHub
Parent:    7b5d6a3f03
Commit:    d6795bd524

22 changed files with 105 additions and 50 deletions


@@ -270,8 +270,11 @@ class TensorFlowArchitecture(Architecture):
         elif self.network_is_trainable:
             # not any of the above but is trainable? -> create an operation for applying the gradients to
             # this network weights
-            self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
-                zip(self.weights_placeholders, self.weights), global_step=self.global_step)
+            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope=self.full_name)
+            with tf.control_dependencies(update_ops):
+                self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
+                    zip(self.weights_placeholders, self.weights), global_step=self.global_step)
 
     def set_session(self, sess):
         self.sess = sess
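
For readers less familiar with the TF1 batchnorm idiom behind this hunk: tf.layers.batch_normalization registers its moving-mean/variance update ops in the tf.GraphKeys.UPDATE_OPS collection, and those ops only run if something explicitly depends on them. The commit scopes the collection to the network's own name and gates the weight-update op on it. Below is a minimal, self-contained sketch of the pattern; the graph and variable names are illustrative, not taken from Coach.

```python
# Minimal TF1 sketch of the UPDATE_OPS / control_dependencies pattern (illustrative names,
# not Coach's graph). Without the control dependency, the batchnorm moving statistics
# would never be refreshed during training.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4], name='x')
y = tf.placeholder(tf.float32, [None, 1], name='y')

h = tf.layers.dense(x, 16, activation=tf.nn.relu)
h = tf.layers.batch_normalization(h, training=True)   # adds its update ops to UPDATE_OPS
pred = tf.layers.dense(h, 1)
loss = tf.reduce_mean(tf.square(pred - y))

# Scoping the collection (as the diff does with scope=self.full_name) keeps one network's
# update ops from being attached to another network's train op.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # the train op now runs only after the moving mean/variance updates have run
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
```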
@@ -414,13 +417,16 @@ class TensorFlowArchitecture(Architecture):
         return feed_dict
 
-    def apply_and_reset_gradients(self, gradients, scaler=1.):
+    def apply_and_reset_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights and resets the accumulation placeholder
 
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
-        self.apply_gradients(gradients, scaler)
+        self.apply_gradients(gradients, scaler, additional_inputs=additional_inputs)
         self.reset_accumulated_gradients()
 
     def wait_for_all_workers_to_lock(self, lock: str, include_only_training_workers: bool=False):
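
The new additional_inputs argument simply threads the current batch down to apply_gradients, because the batchnorm update ops that now gate the apply op read the network inputs. A hypothetical caller sketch follows; network, accumulated_gradients, and batch_states are stand-in names for illustration, not identifiers taken from this commit.

```python
# Hypothetical call pattern: forward the same inputs the gradients were computed from,
# so the batchnorm update ops can be evaluated when the gradients are applied.
network.apply_and_reset_gradients(network.accumulated_gradients,
                                  scaler=1.,
                                  additional_inputs=batch_states)
```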
@@ -460,13 +466,16 @@ class TensorFlowArchitecture(Architecture):
         self.wait_for_all_workers_to_lock('release', include_only_training_workers=include_only_training_workers)
         self.sess.run(self.release_init)
 
-    def apply_gradients(self, gradients, scaler=1.):
+    def apply_gradients(self, gradients, scaler=1., additional_inputs=None):
         """
         Applies the given gradients to the network weights
 
         :param gradients: The gradients to use for the update
         :param scaler: A scaling factor that allows rescaling the gradients before applying them.
                        The gradients will be MULTIPLIED by this factor
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
         if self.network_parameters.async_training or not isinstance(self.ap.task_parameters, DistributedTaskParameters):
             if hasattr(self, 'global_step') and not self.network_is_local:
                 self.sess.run(self.inc_step)
@@ -503,6 +512,8 @@ class TensorFlowArchitecture(Architecture):
             # async distributed training / distributed training with independent optimizer
             # / non-distributed training - just apply the gradients
             feed_dict = dict(zip(self.weights_placeholders, gradients))
+            if additional_inputs is not None:
+                feed_dict = {**feed_dict, **self.create_feed_dict(additional_inputs)}
             self.sess.run(self.update_weights_from_batch_gradients, feed_dict=feed_dict)
 
             # release barrier
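
To make the feed-dict merge concrete: the apply op is now gated on UPDATE_OPS, and those ops read the network's input tensors, so running it with only the gradient placeholders fed would fail with a "placeholder must be fed" error. The small self-contained TF1 example below shows externally computed gradients applied through placeholders with the batch inputs merged into the same feed dict. All names are illustrative stand-ins, not Coach's code.

```python
import numpy as np
import tensorflow as tf

# Illustrative graph: a dense layer followed by batchnorm, so UPDATE_OPS is non-empty.
x = tf.placeholder(tf.float32, [None, 4])
h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
loss = tf.reduce_mean(tf.square(h))

weights = tf.trainable_variables()
grad_ops = tf.gradients(loss, weights)
weights_placeholders = [tf.placeholder(tf.float32, w.shape) for w in weights]

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # applying the placeholder-fed gradients also triggers the batchnorm stat updates
    apply_op = tf.train.GradientDescentOptimizer(0.1).apply_gradients(
        zip(weights_placeholders, weights))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(32, 4).astype(np.float32)
    grads = sess.run(grad_ops, feed_dict={x: batch})

    feed_dict = dict(zip(weights_placeholders, grads))   # gradient values
    feed_dict[x] = batch                                  # the "additional inputs" for UPDATE_OPS
    sess.run(apply_op, feed_dict=feed_dict)
```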