Mirror of https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
batchnorm fixes + disabling batchnorm in DDPG (#353)
Co-authored-by: James Casbon <casbon+gh@gmail.com>
@@ -34,9 +34,9 @@ from rl_coach.spaces import BoxActionSpace, GoalsSpace
 class DDPGCriticNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True),
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm),
                                            'action': InputEmbedderParameters(scheme=EmbedderScheme.Shallow)}
         self.middleware_parameters = FCMiddlewareParameters()
         self.heads_parameters = [DDPGVHeadParameters()]
@@ -53,11 +53,11 @@ class DDPGCriticNetworkParameters(NetworkParameters):
 class DDPGActorNetworkParameters(NetworkParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__()
-        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=True)}
-        self.middleware_parameters = FCMiddlewareParameters(batchnorm=True)
-        self.heads_parameters = [DDPGActorHeadParameters()]
+        self.input_embedders_parameters = {'observation': InputEmbedderParameters(batchnorm=use_batchnorm)}
+        self.middleware_parameters = FCMiddlewareParameters(batchnorm=use_batchnorm)
+        self.heads_parameters = [DDPGActorHeadParameters(batchnorm=use_batchnorm)]
         self.optimizer_type = 'Adam'
         self.batch_size = 64
         self.adam_optimizer_beta2 = 0.999
@@ -109,12 +109,12 @@ class DDPGAlgorithmParameters(AlgorithmParameters):
 class DDPGAgentParameters(AgentParameters):
-    def __init__(self):
+    def __init__(self, use_batchnorm=False):
         super().__init__(algorithm=DDPGAlgorithmParameters(),
                          exploration=OUProcessParameters(),
                          memory=EpisodicExperienceReplayParameters(),
-                         networks=OrderedDict([("actor", DDPGActorNetworkParameters()),
-                                               ("critic", DDPGCriticNetworkParameters())]))
+                         networks=OrderedDict([("actor", DDPGActorNetworkParameters(use_batchnorm=use_batchnorm)),
+                                               ("critic", DDPGCriticNetworkParameters(use_batchnorm=use_batchnorm))]))

     @property
     def path(self):
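Note: with this change batchnorm is disabled by default for DDPG and has to be requested explicitly. A hypothetical usage sketch follows; the import path (assumed to be rl_coach.agents.ddpg_agent) and the assumption that the parameter objects expose the flag as a batchnorm attribute are not part of this commit:

    from rl_coach.agents.ddpg_agent import (DDPGAgentParameters,
                                            DDPGActorNetworkParameters)

    # batchnorm is now off unless explicitly requested
    agent_params = DDPGAgentParameters()                        # new default: no batchnorm
    agent_params_bn = DDPGAgentParameters(use_batchnorm=True)   # opt back in

    # the flag is forwarded to every component of the actor network
    actor_net = DDPGActorNetworkParameters(use_batchnorm=True)
    assert actor_net.middleware_parameters.batchnorm
    assert actor_net.input_embedders_parameters['observation'].batchnorm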
@@ -170,7 +170,9 @@ class DDPGAgent(ActorCriticAgent):
         # train the critic
         critic_inputs = copy.copy(batch.states(critic_keys))
         critic_inputs['action'] = batch.actions(len(batch.actions().shape) == 1)
-        result = critic.train_and_sync_networks(critic_inputs, TD_targets)
+
+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
+        result = critic.train_and_sync_networks(critic_inputs, TD_targets, use_inputs_for_apply_gradients=True)
         total_loss, losses, unclipped_grads = result[:3]

         # apply the gradients from the critic to the actor
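The new comment points at the underlying issue: batchnorm layers keep running estimates of the activation mean and variance, and those estimates only move when a batch is actually pushed through the layer in training mode. A minimal standalone sketch of that bookkeeping (illustrative only, not rl_coach code; the momentum value is arbitrary):

    import numpy as np

    class ToyBatchNorm:
        # toy 1-D batchnorm that tracks running statistics like a real layer
        def __init__(self, num_features, momentum=0.9, eps=1e-5):
            self.running_mean = np.zeros(num_features)
            self.running_var = np.ones(num_features)
            self.momentum = momentum
            self.eps = eps

        def forward(self, x, training=True):
            if training:
                mean, var = x.mean(axis=0), x.var(axis=0)
                # the running statistics are updated ONLY here, i.e. only when
                # inputs are fed through in training mode; applying precomputed
                # gradients without a forward pass leaves them frozen
                self.running_mean = self.momentum * self.running_mean + (1 - self.momentum) * mean
                self.running_var = self.momentum * self.running_var + (1 - self.momentum) * var
            else:
                mean, var = self.running_mean, self.running_var
            return (x - mean) / np.sqrt(var + self.eps)

    bn = ToyBatchNorm(4)
    bn.forward(np.random.randn(64, 4))   # training-mode pass: statistics move
    print(bn.running_mean)               # no longer all zeros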
@@ -179,11 +181,12 @@ class DDPGAgent(ActorCriticAgent):
                                                  outputs=actor.online_network.weighted_gradients[0],
                                                  initial_feed_dict=initial_feed_dict)

+        # also need the inputs for when applying gradients so batchnorm's update of running mean and stddev will work
         if actor.has_global:
-            actor.apply_gradients_to_global_network(gradients)
+            actor.apply_gradients_to_global_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))
             actor.update_online_network()
         else:
-            actor.apply_gradients_to_online_network(gradients)
+            actor.apply_gradients_to_online_network(gradients, additional_inputs=copy.copy(batch.states(critic_keys)))

         return total_loss, losses, unclipped_grads
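For the actor, the gradients are computed in one step and applied in another, so the batch has to be made available again at apply time for the batchnorm update ops to have anything to run on. A standalone TensorFlow 1.x-style sketch of that compute-then-apply split (not rl_coach's implementation; names and hyperparameters are illustrative):

    import numpy as np
    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [None, 4], name='observation')
    h = tf.layers.batch_normalization(tf.layers.dense(x, 8), training=True)
    loss = tf.reduce_mean(tf.square(h))

    opt = tf.train.AdamOptimizer(1e-3)
    grads_and_vars = opt.compute_gradients(loss)
    # placeholders so numeric gradients can be applied in a separate step,
    # mimicking the compute-then-apply split in the agent code above
    grad_phs = [(tf.placeholder(tf.float32, v.shape), v) for _, v in grads_and_vars]
    apply_op = opt.apply_gradients(grad_phs)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # batchnorm moving stats

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        batch = np.random.randn(64, 4).astype(np.float32)

        # step 1: compute the gradients (inputs fed)
        grads = sess.run([g for g, _ in grads_and_vars], feed_dict={x: batch})

        # step 2: apply them; the batchnorm update ops still depend on the
        # inputs, so the batch has to be fed here as well
        feed = {ph: g for (ph, _), g in zip(grad_phs, grads)}
        feed[x] = batch
        sess.run([apply_op, update_ops], feed_dict=feed)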