From dafdb05a7cf52e401101f68b551dc6a5ed0f5a36 Mon Sep 17 00:00:00 2001
From: Gal Novik
Date: Mon, 30 Apr 2018 15:13:29 +0300
Subject: [PATCH] bug fixes for clippedppo and checkpoints

---
 agents/agent.py                                     | 11 ++++++++++-
 agents/clipped_ppo_agent.py                         |  2 +-
 architectures/tensorflow_components/architecture.py | 10 ++++++----
 3 files changed, 17 insertions(+), 6 deletions(-)

diff --git a/agents/agent.py b/agents/agent.py
index 2dd6818..ddba3ef 100644
--- a/agents/agent.py
+++ b/agents/agent.py
@@ -117,6 +117,12 @@ class Agent(object):
         if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
             self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
             self.running_reward_stats = RunningStat(())
+            if self.tp.checkpoint_restore_dir:
+                checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                self.running_observation_stats = read_pickle(checkpoint_path)
+            else:
+                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = RunningStat(())
         else:
             self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                 shape=(self.tp.env.desired_observation_width,),
@@ -247,7 +253,7 @@ class Agent(object):
 
             return observation.astype('uint8')
         else:
-            if self.tp.env.normalize_observation:
+            if self.tp.env.normalize_observation and self.sess is not None:
                 # standardize the input observation using a running mean and std
                 if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
                     self.running_observation_stats.push(observation)
@@ -544,6 +550,9 @@ class Agent(object):
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
                 self.save_model(model_snapshots_periods_passed)
+                to_pickle(self.running_observation_stats,
+                          os.path.join(self.tp.save_model_dir,
+                                       "running_stats.p".format(model_snapshots_periods_passed)))
 
             # play and record in replay buffer
             if self.tp.agent.collect_new_data:
diff --git a/agents/clipped_ppo_agent.py b/agents/clipped_ppo_agent.py
index 88a70b0..f051b31 100644
--- a/agents/clipped_ppo_agent.py
+++ b/agents/clipped_ppo_agent.py
@@ -176,7 +176,7 @@ class ClippedPPOAgent(ActorCriticAgent):
         dataset = dataset[:self.tp.agent.num_consecutive_playing_steps]
 
         if self.tp.distributed and self.tp.agent.share_statistics_between_workers:
-            self.running_observation_stats.push(np.array([t.state['observation'] for t in dataset]))
+            self.running_observation_stats.push(np.array([np.array(t.state['observation']) for t in dataset]))
 
         losses = self.train_network(dataset, 10)
         self.value_loss.add_sample(losses[0])
diff --git a/architectures/tensorflow_components/architecture.py b/architectures/tensorflow_components/architecture.py
index 3474ff4..006ed2c 100644
--- a/architectures/tensorflow_components/architecture.py
+++ b/architectures/tensorflow_components/architecture.py
@@ -125,9 +125,10 @@ class TensorFlowArchitecture(Architecture):
             self.update_weights_from_batch_gradients = self.optimizer.apply_gradients(
                 zip(self.weights_placeholders, self.trainable_weights), global_step=self.global_step)
 
-            current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
-                                                        scope=tf.contrib.framework.get_name_scope())
-            self.merged = tf.summary.merge(current_scope_summaries)
+            if self.tp.visualization.tensorboard:
+                current_scope_summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                                            scope=tf.contrib.framework.get_name_scope())
+                self.merged = tf.summary.merge(current_scope_summaries)
 
             # initialize or restore model
             if not self.tp.distributed:
@@ -197,7 +198,8 @@ class TensorFlowArchitecture(Architecture):
             feed_dict[self.middleware_embedder.c_in] = self.middleware_embedder.c_init
             feed_dict[self.middleware_embedder.h_in] = self.middleware_embedder.h_init
 
-        fetches += [self.merged]
+        if self.tp.visualization.tensorboard:
+            fetches += [self.merged]
 
         # get grads
         result = self.tp.sess.run(fetches, feed_dict=feed_dict)