diff --git a/agents/agent.py b/agents/agent.py
index 0279efd..d6978bc 100644
--- a/agents/agent.py
+++ b/agents/agent.py
@@ -100,8 +100,12 @@ class Agent:
         if self.tp.env.normalize_observation and not self.env.is_state_type_image:
             if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-                self.running_reward_stats = RunningStat(())
+                if self.tp.checkpoint_restore_dir:
+                    checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                    self.running_observation_stats = read_pickle(checkpoint_path)
+                else:
+                    self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = RunningStat(())
             else:
                 self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                     shape=(self.tp.env.desired_observation_width,),
@@ -515,6 +519,9 @@ class Agent:
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
                 self.main_network.save_model(model_snapshots_periods_passed)
+                to_pickle(self.running_observation_stats,
+                          os.path.join(self.tp.save_model_dir,
+                                       "running_stats.p"))
 
         # play and record in replay buffer
         if self.tp.agent.step_until_collecting_full_episodes:
diff --git a/configurations.py b/configurations.py
index aff718a..2fefd45 100644
--- a/configurations.py
+++ b/configurations.py
@@ -442,7 +442,7 @@ class ClippedPPO(AgentParameters):
     batch_size = 64
     use_separate_networks_per_head = True
     step_until_collecting_full_episodes = True
-    beta_entropy = 0.01
+    beta_entropy = 0.001
 
 class DFP(AgentParameters):
     type = 'DFPAgent'
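
Note: the diff calls read_pickle and to_pickle helpers whose definitions are not shown in this patch. A minimal sketch of what such helpers could look like, assuming they are plain wrappers around Python's pickle module (the names and call signatures are taken from the diff; the bodies below are an assumption, not the project's actual implementation):

    import pickle

    def to_pickle(obj, path):
        # Assumed helper: serialize an object (here, the RunningStat instance) to disk.
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    def read_pickle(path):
        # Assumed helper: restore a previously pickled object from disk.
        with open(path, 'rb') as f:
            return pickle.load(f)

With helpers like these, the agent writes running_stats.p into save_model_dir each time a model snapshot is taken, and reads the same file back from checkpoint_restore_dir when restoring, so the observation normalization statistics survive a restart.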