
bug fixes for clippedppo and checkpoints

Gal Novik
2018-04-30 15:13:29 +03:00
parent f31159aad6
commit dafdb05a7c
3 changed files with 17 additions and 6 deletions


@@ -117,6 +117,12 @@ class Agent(object):
         if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-            self.running_reward_stats = RunningStat(())
+            if self.tp.checkpoint_restore_dir:
+                checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                self.running_observation_stats = read_pickle(checkpoint_path)
+            else:
+                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+            self.running_reward_stats = RunningStat(())
         else:
             self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                 shape=(self.tp.env.desired_observation_width,),
@@ -247,7 +253,7 @@ class Agent(object):
             return observation.astype('uint8')
         else:
-            if self.tp.env.normalize_observation:
+            if self.tp.env.normalize_observation and self.sess is not None:
                 # standardize the input observation using a running mean and std
                 if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                     self.running_observation_stats.push(observation)
@@ -544,6 +550,9 @@ class Agent(object):
                 if current_snapshot_period > model_snapshots_periods_passed:
                     model_snapshots_periods_passed = current_snapshot_period
                     self.save_model(model_snapshots_periods_passed)
+                    to_pickle(self.running_observation_stats,
+                              os.path.join(self.tp.save_model_dir,
+                                           "running_stats.p".format(model_snapshots_periods_passed)))
             # play and record in replay buffer
             if self.tp.agent.collect_new_data:
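
Note on the agent changes above: the running observation statistics used for input normalization are now pickled next to each model snapshot and reloaded when a checkpoint restore directory is given, so a restored agent keeps normalizing observations with the same running mean/std it had at save time. Below is a minimal, self-contained sketch of that save/restore round trip; the RunningStat class and the to_pickle/read_pickle helpers here are illustrative stand-ins, not Coach's actual implementations.

    import os
    import pickle
    import numpy as np

    class RunningStat:
        """Illustrative stand-in: running mean/std of observations (Welford update)."""
        def __init__(self, shape):
            self.n = 0
            self.mean = np.zeros(shape)
            self.m2 = np.zeros(shape)

        def push(self, x):
            # update count, mean and sum of squared deviations with one observation
            x = np.asarray(x, dtype=np.float64)
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)

        @property
        def std(self):
            return np.sqrt(self.m2 / max(self.n - 1, 1))

    def to_pickle(obj, path):
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def read_pickle(path):
        with open(path, 'rb') as f:
            return pickle.load(f)

    # save the stats next to a model snapshot ...
    stats = RunningStat((4,))
    for _ in range(100):
        stats.push(np.random.randn(4))
    to_pickle(stats, os.path.join('/tmp', 'running_stats.p'))

    # ... and restore them when resuming from a checkpoint
    restored = read_pickle(os.path.join('/tmp', 'running_stats.p'))
    print(restored.mean, restored.std)

Pickling the whole object keeps the count, mean and variance together, which is what makes normalization consistent across a restart.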


@@ -176,7 +176,7 @@ class ClippedPPOAgent(ActorCriticAgent):
         dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]
         if self.tp.distributed and self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats.push(np.array([t.state['observation'] for t in dataset]))
+            self.running_observation_stats.push(np.array([np.array(t.state['observation']) for t in dataset]))
         losses = self.train_network(dataset, 10)
         self.value_loss.add_sample(losses[0])
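
Note on the clipped PPO change: converting each stored observation with np.array() before stacking makes the pushed batch a plain numeric (N, obs_dim) array even when transitions hold array-like observation wrappers rather than ndarrays. The sketch below shows what a batched push into shared running statistics amounts to, using an illustrative batched Welford/Chan update rather than Coach's SharedRunningStats code; the BatchRunningStat name and the toy dataset are hypothetical.

    import numpy as np

    class BatchRunningStat:
        """Illustrative: merge a whole batch of observations into running mean/variance."""
        def __init__(self, shape):
            self.n = 0
            self.mean = np.zeros(shape)
            self.m2 = np.zeros(shape)

        def push(self, batch):
            # batch is expected to be a numeric (N, obs_dim) array
            batch = np.asarray(batch, dtype=np.float64)
            n_b = batch.shape[0]
            mean_b = batch.mean(axis=0)
            m2_b = ((batch - mean_b) ** 2).sum(axis=0)
            # Chan et al. parallel update: combine batch stats with the running stats
            delta = mean_b - self.mean
            total = self.n + n_b
            self.mean += delta * n_b / total
            self.m2 += m2_b + delta ** 2 * self.n * n_b / total
            self.n = total

    # hypothetical transitions whose observations are lists rather than ndarrays
    dataset = [{'observation': [0.1, 0.2]}, {'observation': [0.3, 0.4]}]
    stats = BatchRunningStat((2,))
    # converting each observation first guarantees a clean float batch of shape (N, 2)
    stats.push(np.array([np.array(t['observation']) for t in dataset]))
    print(stats.mean)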


@@ -125,6 +125,7 @@ class TensorFlowArchitecture(Architecture):
         self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
             zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
+        if self.tp.visualization.tensorboard:
             current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                         scope=tf.contrib.framework.get_name_scope())
             self.merged = tf.summary.merge(current_scope_summaries)
@@ -197,6 +198,7 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
+        if self.tp.visualization.tensorboard:
             fetches += [self.merged]
         # get grads
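
Note on the TensorFlow architecture change: collecting and merging summaries, and fetching self.merged during training, are now guarded by the tensorboard visualization flag, so runs that never write TensorBoard logs neither build nor evaluate summary ops. A minimal TF 1.x-style sketch of the same guard is below; the names (build_net, tiny_net, use_tensorboard) are hypothetical and stand in for the classes in this repository.

    import numpy as np
    import tensorflow as tf  # assumes TensorFlow 1.x, as used in the diff above

    def build_net(use_tensorboard):
        with tf.name_scope('tiny_net'):
            x = tf.placeholder(tf.float32, [None, 4], name='x')
            out = tf.layers.dense(x, 2, name='out')
            loss = tf.reduce_mean(tf.square(out))
            tf.summary.scalar('loss', loss)
            merged = None
            if use_tensorboard:
                # merge only the summaries created inside this scope, and only when requested
                scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope='tiny_net')
                merged = tf.summary.merge(scope_summaries)
        return x, loss, merged

    x, loss, merged = build_net(use_tensorboard=True)
    fetches = [loss]
    if merged is not None:
        fetches += [merged]  # fetch summaries only when tensorboard is enabled
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        results = sess.run(fetches, feed_dict={x: np.zeros((1, 4), dtype=np.float32)})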