1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

adding pickling of running_stats and updating the beta entropy for ClippedPPO

This commit is contained in:
itaicaspi-intel
2018-03-19 14:37:05 +02:00
parent f7979b05e4
commit 24a0f24279
2 changed files with 10 additions and 3 deletions

View File

@@ -100,8 +100,12 @@ class Agent:
         if self.tp.env.normalize_observation and not self.env.is_state_type_image:
             if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-                self.running_reward_stats = RunningStat(())
+                if self.tp.checkpoint_restore_dir:
+                    checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                    self.running_observation_stats = read_pickle(checkpoint_path)
+                else:
+                    self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = RunningStat(())
             else:
                 self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                     shape=(self.tp.env.desired_observation_width,),
@@ -515,6 +519,9 @@ class Agent:
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
                 self.main_network.save_model(model_snapshots_periods_passed)
+                to_pickle(self.running_observation_stats,
+                          os.path.join(self.tp.save_model_dir,
+                                       "running_stats.p".format(model_snapshots_periods_passed)))

         # play and record in replay buffer
         if self.tp.agent.step_until_collecting_full_episodes:

View File

@@ -442,7 +442,7 @@ class ClippedPPO(AgentParameters):
     batch_size = 64
     use_separate_networks_per_head = True
     step_until_collecting_full_episodes = True
-    beta_entropy = 0.01
+    beta_entropy = 0.001


 class DFP(AgentParameters):
     type = 'DFPAgent'