Mirror of https://github.com/gryf/coach.git (synced 2025-12-17 19:20:19 +01:00)

bug fixes for clippedppo and checkpoints

@@ -117,6 +117,12 @@ class Agent(object):
         if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
             self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
             self.running_reward_stats = RunningStat(())
+            if self.tp.checkpoint_restore_dir:
+                checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                self.running_observation_stats = read_pickle(checkpoint_path)
+            else:
+                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = RunningStat(())
         else:
             self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                 shape=(self.tp.env.desired_observation_width,),
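
The restore branch added here means a resumed run reloads the pickled observation statistics instead of starting from fresh zero-valued stats, so normalization matches what the saved model was trained against. The diff calls to_pickle/read_pickle without showing them; a minimal stand-in sketch of the round trip (helper names as used in the diff, bodies assumed):

    import os
    import pickle

    def to_pickle(obj, path):
        # assumed helper: dump any picklable object (e.g. a RunningStat) to disk
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def read_pickle(path):
        # assumed helper: restore the object exactly as it was saved
        with open(path, 'rb') as f:
            return pickle.load(f)

    # restore-or-create, mirroring the branch added above
    checkpoint_restore_dir = None  # stand-in for self.tp.checkpoint_restore_dir
    if checkpoint_restore_dir:
        stats = read_pickle(os.path.join(checkpoint_restore_dir, "running_stats.p"))
    else:
        stats = {"mean": 0.0, "std": 1.0}  # placeholder for a fresh RunningStat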

@@ -247,7 +253,7 @@ class Agent(object):

             return observation.astype('uint8')
         else:
-            if self.tp.env.normalize_observation:
+            if self.tp.env.normalize_observation and self.sess is not None:
                 # standardize the input observation using a running mean and std
                 if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                     self.running_observation_stats.push(observation)
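
This hunk adds a second condition, self.sess is not None, so observations are neither pushed into nor standardized by the running stats before a TensorFlow session exists. For reference, running mean/std standardization of the kind RunningStat provides looks roughly like this (an illustrative sketch using Welford's algorithm, not coach's implementation):

    import numpy as np

    class SimpleRunningStat(object):
        # illustrative stand-in for RunningStat: incremental mean/variance (Welford)
        def __init__(self, shape):
            self.n = 0
            self.mean = np.zeros(shape)
            self.m2 = np.zeros(shape)

        def push(self, x):
            x = np.asarray(x, dtype=np.float64)
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)

        @property
        def std(self):
            return np.sqrt(self.m2 / self.n) if self.n > 1 else np.ones_like(self.mean)

    stats = SimpleRunningStat((4,))
    for _ in range(100):
        stats.push(np.random.randn(4) * 3.0 + 5.0)
    observation = np.random.randn(4)
    standardized = (observation - stats.mean) / (stats.std + 1e-8)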

@@ -544,6 +550,9 @@ class Agent(object):
                 if current_snapshot_period > model_snapshots_periods_passed:
                     model_snapshots_periods_passed = current_snapshot_period
                     self.save_model(model_snapshots_periods_passed)
+                    to_pickle(self.running_observation_stats,
+                              os.path.join(self.tp.save_model_dir,
+                                           "running_stats.p".format(model_snapshots_periods_passed)))

             # play and record in replay buffer
             if self.tp.agent.collect_new_data:
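
Two observations on the save path: the stats pickle is written right after each save_model call so the checkpoint and its statistics stay in sync, and "running_stats.p".format(model_snapshots_periods_passed) contains no {} placeholder, so format() is a no-op and every period overwrites the same running_stats.p file (only the latest stats survive, unlike the numbered model snapshots). The cadence itself, as a toy sketch with made-up numbers:

    # one snapshot per elapsed period; stand-in values, not coach's config
    steps_per_snapshot_period = 1000
    model_snapshots_periods_passed = -1
    for training_step in range(3500):
        current_snapshot_period = training_step // steps_per_snapshot_period
        if current_snapshot_period > model_snapshots_periods_passed:
            model_snapshots_periods_passed = current_snapshot_period
            # here the agent calls save_model(...) and, with this commit,
            # also pickles the running observation stats alongside it
            print("snapshot at period", model_snapshots_periods_passed)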

@@ -176,7 +176,7 @@ class ClippedPPOAgent(ActorCriticAgent):
         dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]

         if self.tp.distributed and self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats.push(np.array([t.state['observation'] for t in dataset]))
+            self.running_observation_stats.push(np.array([np.array(t.state['observation']) for t in dataset]))

         losses = self.train_network(dataset, 10)
         self.value_loss.add_sample(losses[0])
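
The ClippedPPO fix coerces each stored observation with np.array(...) before stacking the batch, which presumably guarantees a well-formed numeric array even when t.state['observation'] holds a list or a lazily evaluated frame wrapper rather than an ndarray. A toy illustration with a hypothetical wrapper type:

    import numpy as np

    class LazyObs(object):
        # hypothetical stand-in for whatever t.state['observation'] may hold
        def __init__(self, data):
            self._data = data

        def __array__(self, dtype=None):
            # lets np.array(obs) materialize the underlying data
            return np.asarray(self._data, dtype=dtype)

    observations = [LazyObs([0.1, 0.2, 0.3]) for _ in range(5)]
    batch = np.array([np.array(o) for o in observations])
    print(batch.shape, batch.dtype)  # (5, 3) float64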

@@ -125,6 +125,7 @@ class TensorFlowArchitecture(Architecture):
         self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
             zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)

+        if self.tp.visualization.tensorboard:
             current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                         scope=tf.contrib.framework.get_name_scope())
             self.merged = tf.summary.merge(current_scope_summaries)
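
Both architecture hunks gate TensorBoard plumbing on the same visualization flag. On the graph-construction side, the scope's summaries are only collected and merged when TensorBoard output is requested, so self.merged is simply never built otherwise. A construction-side sketch (assumes TensorFlow 1.x APIs, matching the tf.contrib usage above):

    import tensorflow as tf  # TensorFlow 1.x assumed

    tensorboard_enabled = True  # stand-in for self.tp.visualization.tensorboard

    loss = tf.placeholder(tf.float32, shape=(), name='loss')
    tf.summary.scalar('loss', loss)

    merged = None
    if tensorboard_enabled:
        # merge only the summaries registered in the graph so far
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        merged = tf.summary.merge(summaries)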

@@ -197,6 +198,7 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init

+        if self.tp.visualization.tensorboard:
             fetches += [self.merged]

         # get grads
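
This hunk is the matching run-time guard: self.merged exists only when visualization is on, so it is appended to the session fetches under the same flag, keeping the fetch list consistent with the graph. Continuing the sketch above:

    fetches = [tf.no_op(name='train_step_stand_in')]  # whatever the step fetches
    if tensorboard_enabled:
        fetches += [merged]  # safe: merged was built under the same flag

    with tf.Session() as sess:
        results = sess.run(fetches, feed_dict={loss: 0.5})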