
batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>
Authored by Gal Leibovich on 2019-06-23 11:28:22 +03:00, committed by GitHub
parent 7b5d6a3f03
commit d6795bd524
22 changed files with 105 additions and 50 deletions


@@ -124,31 +124,37 @@ class NetworkWrapper(object):
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)

-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)

-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.
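For context on why the apply-gradients step now takes the inputs: TensorFlow's batchnorm layers register moving-mean/variance update ops that depend on the input tensors, so those ops can only run if the inputs are fed again when the gradients are applied. A minimal, self-contained TF1-style sketch of that mechanism (illustrative names only, not Coach's actual network code):

# Illustrative sketch, not Coach's GeneralTensorFlowNetwork:
# batchnorm's moving-mean/variance update ops depend on the input placeholder,
# so they can only be run if the inputs are fed again at apply-gradients time.
import numpy as np
import tensorflow as tf

obs = tf.placeholder(tf.float32, [None, 4], name='observation')
hidden = tf.layers.dense(obs, 16)
hidden = tf.layers.batch_normalization(hidden, training=True)  # registers ops in UPDATE_OPS
loss = tf.reduce_mean(tf.square(tf.layers.dense(hidden, 1)))

optimizer = tf.train.AdamOptimizer(1e-3)
apply_op = optimizer.apply_gradients(optimizer.compute_gradients(loss))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batchnorm statistics updates

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(32, 4).astype(np.float32)
    # running the update ops requires the inputs -- this is what additional_inputs feeds
    sess.run([apply_op] + update_ops, feed_dict={obs: batch})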
@@ -157,14 +163,20 @@ class NetworkWrapper(object):
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                    error of this sample. If it is not given, the samples' losses won't be scaled
+        :param use_inputs_for_apply_gradients: Also add the inputs when applying the gradients
+                                               (e.g. for incorporating batchnorm update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                            importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)

         return result

-    def apply_gradients_and_sync_networks(self, reset_gradients=True):
+    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
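A hedged usage sketch of the new flag from an agent's training step (the network handle, input key and targets below are placeholders, not taken from this commit):

# Hypothetical agent-side call; `network` is a NetworkWrapper, `observations` and
# `td_targets` are whatever the agent already prepares for training.
result = network.train_and_sync_networks(
    inputs={'observation': observations},
    targets=td_targets,
    # feed the same inputs again when the gradients are applied, so that
    # batchnorm's update ops can run as part of the apply-gradients step
    use_inputs_for_apply_gradients=True)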
@@ -173,17 +185,22 @@ class NetworkWrapper(object):
                                 the network. this is useful when the accumulated gradients are overwritten instead
                                 of accumulated by the accumulate_gradients function. this allows reducing time
                                 complexity for this function by around 10%
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()
             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                               additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)

     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """