batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -124,31 +124,37 @@ class NetworkWrapper(object):
         if self.global_network:
             self.online_network.set_weights(self.global_network.get_weights(), rate)
 
-    def apply_gradients_to_global_network(self, gradients=None):
+    def apply_gradients_to_global_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on the global network
 
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
         if self.network_parameters.shared_optimizer:
-            self.global_network.apply_gradients(gradients)
+            self.global_network.apply_gradients(gradients, additional_inputs=additional_inputs)
         else:
-            self.online_network.apply_gradients(gradients)
+            self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
 
-    def apply_gradients_to_online_network(self, gradients=None):
+    def apply_gradients_to_online_network(self, gradients=None, additional_inputs=None):
         """
         Apply gradients from the online network on itself
         :param gradients: optional gradients that will be used instead of the accumulated gradients
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
 
         :return:
         """
         if gradients is None:
             gradients = self.online_network.accumulated_gradients
-        self.online_network.apply_gradients(gradients)
+        self.online_network.apply_gradients(gradients, additional_inputs=additional_inputs)
 
-    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None):
+    def train_and_sync_networks(self, inputs, targets, additional_fetches=[], importance_weights=None,
+                                use_inputs_for_apply_gradients=False):
         """
         A generic training function that enables multi-threading training using a global network if necessary.
 
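Why the apply step suddenly needs the inputs: in the TF1-style graphs Coach builds, batchnorm layers register their moving mean/variance updates in tf.GraphKeys.UPDATE_OPS, and those update ops read the input tensors, so they can only run with a feed dict that supplies the batch. A minimal standalone sketch of that mechanism (illustrative only, not Coach's backend code):

# Standalone TF1 sketch, not Coach code: batchnorm's update ops live in
# tf.GraphKeys.UPDATE_OPS and depend on the input placeholder, so refreshing
# the moving statistics at apply time needs the inputs fed in again.
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.placeholder(tf.float32, [None, 8], name='x')
h = tf.layers.batch_normalization(tf.layers.dense(x, 16), training=True)
loss = tf.reduce_mean(tf.square(h))

optimizer = tf.train.AdamOptimizer(1e-3)
apply_op = optimizer.apply_gradients(optimizer.compute_gradients(loss))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # moving mean/var updates

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(32, 8).astype(np.float32)
    # Running apply_op alone would leave the moving statistics stale;
    # the update ops need the same inputs that produced the gradients.
    sess.run([apply_op, update_ops], feed_dict={x: batch})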
@@ -157,14 +163,20 @@ class NetworkWrapper(object):
         :param additional_fetches: Any additional tensor the user wants to fetch
         :param importance_weights: A coefficient for each sample in the batch, which will be used to rescale the loss
                                    error of this sample. If it is not given, the samples' losses won't be scaled
+        :param use_inputs_for_apply_gradients: Pass the inputs also when applying the gradients
+                                               (e.g. for incorporating batchnorm's update ops)
         :return: The loss of the training iteration
         """
         result = self.online_network.accumulate_gradients(inputs, targets, additional_fetches=additional_fetches,
                                                           importance_weights=importance_weights, no_accumulation=True)
-        self.apply_gradients_and_sync_networks(reset_gradients=False)
+        if use_inputs_for_apply_gradients:
+            self.apply_gradients_and_sync_networks(reset_gradients=False, additional_inputs=inputs)
+        else:
+            self.apply_gradients_and_sync_networks(reset_gradients=False)
 
         return result
 
-    def apply_gradients_and_sync_networks(self, reset_gradients=True):
+    def apply_gradients_and_sync_networks(self, reset_gradients=True, additional_inputs=None):
         """
         Applies the gradients accumulated in the online network to the global network or to itself and syncs the
         networks if necessary
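From the caller's side, an agent whose network uses batchnorm opts in through the new flag, and the wrapper forwards the same training batch to the apply step. A hedged usage sketch (the surrounding names are hypothetical, not part of this commit):

# Hypothetical call site; `network`, `batch_states` and `batch_targets`
# are assumed names, not taken from this commit.
result = network.train_and_sync_networks(
    inputs=batch_states,
    targets=batch_targets,
    use_inputs_for_apply_gradients=True,  # forward the batch so batchnorm's
)                                         # update ops can run at apply time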
@@ -173,17 +185,22 @@ class NetworkWrapper(object):
                                 the network. This is useful when the accumulated gradients are overwritten instead
                                 of accumulated by the accumulate_gradients function. This allows reducing the time
                                 complexity of this function by around 10%
+        :param additional_inputs: optional additional inputs required when applying the gradients (e.g. batchnorm's
+                                  update ops also require the inputs)
 
         """
         if self.global_network:
-            self.apply_gradients_to_global_network()
+            self.apply_gradients_to_global_network(additional_inputs=additional_inputs)
             if reset_gradients:
                 self.online_network.reset_accumulated_gradients()
             self.update_online_network()
         else:
             if reset_gradients:
-                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_and_reset_gradients(self.online_network.accumulated_gradients,
+                                                              additional_inputs=additional_inputs)
             else:
-                self.online_network.apply_gradients(self.online_network.accumulated_gradients)
+                self.online_network.apply_gradients(self.online_network.accumulated_gradients,
+                                                    additional_inputs=additional_inputs)
 
     def parallel_prediction(self, network_input_tuples: List[Tuple]):
         """
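The receiving side of additional_inputs (the architecture-level apply_gradients) is not shown in these hunks. One plausible shape for it, assuming a TF1 session and assumed attribute names (weights_placeholders, input_placeholders, update_ops), would be:

# Plausible consumer sketch with assumed attribute names; the actual
# Coach implementation may differ.
def apply_gradients(self, gradients, additional_inputs=None):
    # Feed the incoming gradients into their placeholders as usual.
    feed_dict = dict(zip(self.weights_placeholders, gradients))
    fetches = [self.apply_gradients_op]
    if additional_inputs is not None:
        # Also feed the batch inputs so the batchnorm UPDATE_OPS can run
        # in the same session call as the gradient application.
        feed_dict.update(zip(self.input_placeholders, additional_inputs))
        fetches.extend(self.update_ops)
    self.sess.run(fetches, feed_dict=feed_dict)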