Mirror of https://github.com/gryf/coach.git (synced 2025-12-18 03:30:19 +01:00)
adding pickling of running_stats and updating the beta entropy for ClippedPPO
@@ -100,8 +100,12 @@ class Agent:
         if self.tp.env.normalize_observation and not self.env.is_state_type_image:
             if not self.tp.distributed or not self.tp.agent.share_statistics_between_workers:
-                self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
-                self.running_reward_stats = RunningStat(())
+                if self.tp.checkpoint_restore_dir:
+                    checkpoint_path = os.path.join(self.tp.checkpoint_restore_dir, "running_stats.p")
+                    self.running_observation_stats = read_pickle(checkpoint_path)
+                else:
+                    self.running_observation_stats = RunningStat((self.tp.env.desired_observation_width,))
+                self.running_reward_stats = RunningStat(())
             else:
                 self.running_observation_stats = SharedRunningStats(self.tp, replicated_device,
                                                                     shape=(self.tp.env.desired_observation_width,),
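The new branch restores the observation-normalization statistics from the checkpoint directory instead of creating a fresh RunningStat. Below is a minimal sketch, not the Coach implementation, of what a RunningStat object and the to_pickle/read_pickle helpers referenced in the diff could look like if they are thin wrappers around the standard pickle module.

    import pickle

    import numpy as np


    class RunningStat:
        """Tracks a running mean and variance of observations (Welford-style update)."""
        def __init__(self, shape):
            self.n = 0
            self.mean = np.zeros(shape)
            self.m2 = np.zeros(shape)

        def push(self, x):
            # online update of mean and sum of squared deviations
            x = np.asarray(x)
            self.n += 1
            delta = x - self.mean
            self.mean += delta / self.n
            self.m2 += delta * (x - self.mean)

        @property
        def std(self):
            return np.sqrt(self.m2 / max(self.n - 1, 1))


    def to_pickle(obj, path):
        # serialize the statistics so a restored run keeps normalizing
        # observations with the same mean/std it had when the model was saved
        with open(path, 'wb') as f:
            pickle.dump(obj, f)


    def read_pickle(path):
        # inverse of to_pickle(); used when tp.checkpoint_restore_dir is set
        with open(path, 'rb') as f:
            return pickle.load(f)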
@@ -515,6 +519,9 @@ class Agent:
             if current_snapshot_period > model_snapshots_periods_passed:
                 model_snapshots_periods_passed = current_snapshot_period
                 self.main_network.save_model(model_snapshots_periods_passed)
+                to_pickle(self.running_observation_stats,
+                          os.path.join(self.tp.save_model_dir,
+                                       "running_stats.p".format(model_snapshots_periods_passed)))
 
             # play and record in replay buffer
             if self.tp.agent.step_until_collecting_full_episodes:
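Note that the file-name template "running_stats.p" contains no replacement field, so the .format(model_snapshots_periods_passed) call is a no-op and every snapshot period overwrites the same file. A small illustration of that behaviour (the directory below is hypothetical, used only for the example):

    import os

    save_model_dir = "/tmp/coach_experiment"   # hypothetical path for the example
    model_snapshots_periods_passed = 3

    # no "{}" placeholder, so .format() returns the string unchanged
    path = os.path.join(save_model_dir,
                        "running_stats.p".format(model_snapshots_periods_passed))
    print(path)   # /tmp/coach_experiment/running_stats.p

    # a per-snapshot file would need an explicit placeholder instead:
    versioned = os.path.join(save_model_dir,
                             "running_stats_{}.p".format(model_snapshots_periods_passed))
    print(versioned)   # /tmp/coach_experiment/running_stats_3.p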
@@ -442,7 +442,7 @@ class ClippedPPO(AgentParameters):
     batch_size = 64
     use_separate_networks_per_head = True
     step_until_collecting_full_episodes = True
-    beta_entropy = 0.01
+    beta_entropy = 0.001
 
 class DFP(AgentParameters):
     type = 'DFPAgent'
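beta_entropy scales the entropy bonus in the Clipped PPO objective, so lowering it from 0.01 to 0.001 weakens the exploration pressure by an order of magnitude. A schematic sketch of where the coefficient enters the loss, assuming the standard clipped-surrogate formulation rather than Coach's exact loss code:

    import numpy as np

    def clipped_ppo_loss(ratio, advantage, entropy, clip_epsilon=0.2, beta_entropy=0.001):
        # pessimistic clipped surrogate: elementwise minimum of the unclipped
        # and clipped policy-ratio terms
        surrogate = np.minimum(ratio * advantage,
                               np.clip(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage)
        # entropy bonus scaled by beta_entropy encourages exploration;
        # this commit drops the coefficient from 0.01 to 0.001
        return -(surrogate + beta_entropy * entropy).mean()

    # example with dummy per-sample values
    ratio = np.array([0.9, 1.3])        # pi_new / pi_old
    advantage = np.array([1.0, -0.5])
    entropy = np.array([1.2, 1.1])
    print(clipped_ppo_loss(ratio, advantage, entropy))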