Mirror of https://github.com/gryf/coach.git, synced 2025-12-17 11:10:20 +01:00
Commit: bug fixes for clippedppo and checkpoints
@@ -117,6 +117,12 @@ class Agent(object):
         if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-            self.running_reward_stats = RunningStat(())
+            if self.tp.checkpoint_restore_dir:
+                checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                self.running_observation_stats = read_pickle(checkpoint_path)
+            else:
+                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+            self.running_reward_stats = RunningStat(())
         else:
             self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                 shape=(self.tp.env.desired_observation_width,),
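Note on the restore branch above: it expects a running_stats.p file written at save time (see the snapshot hunk further down). A minimal sketch of that pickle round-trip, assuming read_pickle/to_pickle are thin wrappers around the standard pickle module (the wrappers below are illustrative stand-ins, not the repo's actual helpers):

    import os
    import pickle

    def to_pickle(obj, path):
        # serialize the RunningStat object (mean/var/count state) to disk
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def read_pickle(path):
        # restore the exact RunningStat instance saved at checkpoint time
        with open(path, 'rb') as f:
            return pickle.load(f)

    # usage mirroring the diff (directories are illustrative)
    # to_pickle(agent.running_observation_stats, os.path.join(save_model_dir, "running_stats.p"))
    # agent.running_observation_stats = read_pickle(os.path.join(checkpoint_restore_dir, "running_stats.p"))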
@@ -247,7 +253,7 @@ class Agent(object):

             return observation.astype('uint8')
         else:
-            if self.tp.env.normalize_observation:
+            if self.tp.env.normalize_observation and self.sess is not None:
                 # standardize the input observation using a running mean and std
                 if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                     self.running_observation_stats.push(observation)
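The added `and self.sess is not None` guard keeps observation normalization from running before a TensorFlow session is attached, presumably because the shared running statistics are backed by TF ops. A rough sketch of the guarded branch, assuming RunningStat exposes mean and std attributes (the function name and epsilon below are illustrative):

    def _normalize_observation(self, observation):
        # only standardize once a session exists and normalization is enabled
        if self.tp.env.normalize_observation and self.sess is not None:
            self.running_observation_stats.push(observation)
            return (observation - self.running_observation_stats.mean) / \
                   (self.running_observation_stats.std + 1e-8)
        # otherwise pass the observation through unchanged
        return observation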
@@ -544,6 +550,9 @@ class Agent(object):
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
                 self.save_model(model_snapshots_periods_passed)
+                to_pickle(self.running_observation_stats,
+                          os.path.join(self.tp.save_model_dir,
+                                       "running_stats.p".format(model_snapshots_periods_passed)))

             # play and record in replay buffer
             if self.tp.agent.collect_new_data:
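Worth noting: "running_stats.p".format(model_snapshots_periods_passed) contains no "{}" placeholder, so the argument is ignored and every snapshot overwrites the same running_stats.p in the save directory; that single file is exactly what the restore hunk above reads back. If one stats file per snapshot period were wanted instead, the name would need a placeholder, for example (illustrative only, not what the commit does):

    snapshot_name = "running_stats_{}.p".format(model_snapshots_periods_passed)
    to_pickle(self.running_observation_stats,
              os.path.join(self.tp.save_model_dir, snapshot_name))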
@@ -176,7 +176,7 @@ class ClippedPPOAgent(ActorCriticAgent):
         dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]

         if self.tp.distributed and self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats.push(np.array([t.state['observation'] for t in dataset]))
+            self.running_observation_stats.push(np.array([np.array(t.state['observation']) for t in dataset]))

         losses = self.train_network(dataset, 10)
         self.value_loss.add_sample(losses[0])
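The extra np.array() around each stored observation matters when t.state['observation'] is not already a plain ndarray (e.g. a lazy frame wrapper or a nested list): stacking raw objects can produce a ragged object array, while converting each element first yields a regular numeric batch. A small sketch of the intent (names and shapes are illustrative):

    import numpy as np

    # convert each stored observation to an ndarray, then stack them into a
    # (num_transitions, observation_width) batch before updating the stats
    observations = [np.array(t.state['observation']) for t in dataset]
    batch = np.array(observations)
    running_observation_stats.push(batch)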
@@ -125,9 +125,10 @@ class TensorFlowArchitecture(Architecture):
             self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
                 zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)

-            current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
-                                                        scope=tf.contrib.framework.get_name_scope())
-            self.merged = tf.summary.merge(current_scope_summaries)
+            if self.tp.visualization.tensorboard:
+                current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                                            scope=tf.contrib.framework.get_name_scope())
+                self.merged = tf.summary.merge(current_scope_summaries)

         # initialize or restore model
         if not self.tp.distributed:
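Building the merged summary op only when self.tp.visualization.tensorboard is set avoids creating summary ops that nothing will consume. A minimal sketch of the pattern with an explicit sentinel for the disabled case (whether the original code sets such a default elsewhere is an assumption):

    self.merged = None
    if self.tp.visualization.tensorboard:
        # collect only the summaries created under the current name scope
        current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                    scope=tf.contrib.framework.get_name_scope())
        self.merged = tf.summary.merge(current_scope_summaries)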
@@ -197,7 +198,8 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init

-        fetches += [self.merged]
+        if self.tp.visualization.tensorboard:
+            fetches += [self.merged]

         # get grads
         result = self.tp.sess.run(fetches, feed_dict=feed_dict)
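This is the fetch-time counterpart of the build-time guard above: self.merged is only appended to the session.run fetch list when TensorBoard output is enabled, so the disabled case never tries to run a summary op that was never built. A rough sketch of how such a conditional fetch is typically consumed (the fetch list contents, summary_writer, and step counter below are illustrative, not taken from the repo):

    fetches = [self.gradients, self.total_loss]
    if self.tp.visualization.tensorboard:
        fetches += [self.merged]

    result = self.tp.sess.run(fetches, feed_dict=feed_dict)

    if self.tp.visualization.tensorboard:
        # the merged summary is the last fetched tensor
        self.summary_writer.add_summary(result[-1], global_step)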