
batchnorm fixes + disabling batchnorm in DDPG (#353)

Co-authored-by: James Casbon <casbon+gh@gmail.com>
Gal Leibovich committed 2019-06-23 11:28:22 +03:00 (committed by GitHub)
parent 7b5d6a3f03
commit d6795bd524
22 changed files with 105 additions and 50 deletions


@@ -34,9 +34,9 @@ from rl_coach.spaces import BoxActionSpace, GoalsSpace
 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ class DDPGCriticNetworkParameters(NetworkParameters):
 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999
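
Both network-parameter classes now follow the same pattern: batch normalization is off by default and is only enabled when the caller passes use_batchnorm=True, with that single flag fanning out to every sub-component. A minimal sketch of the pattern using hypothetical stand-in classes (EmbedderParams, MiddlewareParams, ActorNetworkParams are illustrative names, not the real rl_coach classes):

    # Hypothetical stand-ins that mimic the pattern above; not the real rl_coach classes.
    class EmbedderParams:
        def __init__(self, batchnorm=False):
            self.batchnorm = batchnorm

    class MiddlewareParams:
        def __init__(self, batchnorm=False):
            self.batchnorm = batchnorm

    class ActorNetworkParams:
        def __init__(self, use_batchnorm=False):
            # one constructor flag fans out to every sub-component, so batchnorm
            # is either enabled everywhere in the network or disabled everywhere
            self.input_embedders = {'observation': EmbedderParams(batchnorm=use_batchnorm)}
            self.middleware = MiddlewareParams(batchnorm=use_batchnorm)

    default_params = ActorNetworkParams()               # batchnorm disabled (the new default)
    bn_params = ActorNetworkParams(use_batchnorm=True)  # explicit opt-in re-enables it
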
@@ -109,12 +109,12 @@ class DDPGAlgorithmParameters(AlgorithmParameters):
 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
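
With the flag threaded through DDPGAgentParameters, a preset can opt back into batch normalization with a single argument. A short usage sketch, assuming only the constructor signatures shown in this diff, the usual rl_coach.agents.ddpg_agent import path, and that the batchnorm flag is stored as an attribute on the sub-parameter objects:

    # Usage sketch; assumes the constructor signatures shown in this diff.
    from rl_coach.agents.ddpg_agent import DDPGAgentParameters, DDPGActorNetworkParameters

    agent_params = DDPGAgentParameters()                       # batchnorm off for actor and critic (new default)
    agent_params_bn = DDPGAgentParameters(use_batchnorm=True)  # opt back in for both networks

    # the flag reaches the actor's embedder, middleware and head parameters
    actor_net = DDPGActorNetworkParameters(use_batchnorm=True)
    assert actor_net.middleware_parameters.batchnorm is True
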
@@ -170,7 +170,9 @@ class DDPGAgent(ActorCriticAgent):
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
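
The new use_inputs_for_apply_gradients=True matters because batch normalization keeps running mean/variance statistics that are updated as a side effect of training, and in TF1-style graphs those update ops read the same input placeholders as the forward pass; applying gradients without feeding the inputs would therefore skip or fail the statistics update. A rough, generic TF1 illustration of that dependency (not Coach code, all names hypothetical):

    # Generic TF1-style illustration (not Coach code) of why batchnorm's
    # running-statistics updates need the network inputs to be fed.
    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [None, 4])
    h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
    loss = tf.reduce_mean(tf.square(h))

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batchnorm moving mean/var updates
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.randn(32, 4).astype(np.float32)
        sess.run(train_op, feed_dict={x: batch})  # OK: inputs fed, so update_ops can run
        # sess.run(train_op) without feeding x would raise a missing-placeholder error,
        # because the moving-statistics updates depend on the inputs; this is why the
        # apply-gradients step in the agent now also receives the batch of states
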
@@ -179,11 +181,12 @@ class DDPGAgent(ActorCriticAgent):
                                                  outputs=actor.online_network.weighted_gradients[0],
                                                  initial_feed_dict=initial_feed_dict)

+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads
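
For the actor, the gradients are computed in one session call and applied in a second one, so without the new additional_inputs argument the apply step would never see the observations and batchnorm's running statistics would not be refreshed there. A hypothetical sketch of how an apply-gradients helper could merge such extra inputs into its feed dict (generic code with made-up names, not the real rl_coach NetworkWrapper implementation):

    # Hypothetical helper, not the real rl_coach NetworkWrapper implementation.
    def apply_external_gradients(sess, apply_grads_op, grad_placeholders, gradients,
                                 input_placeholders=None, additional_inputs=None,
                                 update_ops=()):
        """Apply precomputed gradients; optionally feed the original network inputs
        so side-effect ops (e.g. batchnorm moving-average updates) can also run."""
        feed_dict = dict(zip(grad_placeholders, gradients))
        if additional_inputs is not None:
            # the observation batch is fed alongside the gradients so that
            # batchnorm's running mean/stddev are updated during the apply step
            feed_dict.update({input_placeholders[k]: v for k, v in additional_inputs.items()})
        return sess.run([apply_grads_op] + list(update_ops), feed_dict=feed_dict)
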