mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 03:30:19 +01:00
Fixes for having NumpySharedRunningStats syncing on multi-node (#139)
1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3. 2. Removing the reference to Redis so that it won't try to pickle that in. 3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run.
This commit is contained in:
@@ -925,30 +925,31 @@ class Agent(AgentInterface):
|
||||
self.input_filter.observation_filters['attention'].crop_high = action[1]
|
||||
self.output_filter.action_filters['masking'].set_masking(action[0], action[1])
|
||||
|
||||
def save_checkpoint(self, checkpoint_id: int) -> None:
|
||||
def save_checkpoint(self, checkpoint_prefix: str) -> None:
|
||||
"""
|
||||
Allows agents to store additional information when saving checkpoints.
|
||||
|
||||
:param checkpoint_id: the id of the checkpoint
|
||||
:param checkpoint_prefix: The prefix of the checkpoint file to save
|
||||
:return: None
|
||||
"""
|
||||
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
|
||||
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
checkpoint_dir = self.ap.task_parameters.checkpoint_save_dir
|
||||
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix] + self.full_name_id.split('/')) # adds both level name and agent name
|
||||
|
||||
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||
"""
|
||||
Allows agents to store additional information when saving checkpoints.
|
||||
|
||||
:param checkpoint_id: the id of the checkpoint
|
||||
:param checkpoint_dir: The checkpoint dir to restore from
|
||||
:return: None
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir,
|
||||
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||
self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
checkpoint_prefix = '.'.join(self.full_name_id.split('/')) # adds both level name and agent name
|
||||
self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
# no output filters currently have an internal state to restore
|
||||
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||
|
||||
@@ -98,10 +98,10 @@ class AgentInterface(object):
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
def save_checkpoint(self, checkpoint_id: int) -> None:
|
||||
def save_checkpoint(self, checkpoint_prefix: str) -> None:
|
||||
"""
|
||||
Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc.
|
||||
:param checkpoint_id: the checkpoint id to use for saving
|
||||
:param checkpoint_prefix: The prefix of the checkpoint file to save
|
||||
:return: None
|
||||
"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@@ -389,8 +389,8 @@ class CompositeAgent(AgentInterface):
|
||||
# probably better to only return the agents' goal_reached decisions.
|
||||
return episode_ended
|
||||
|
||||
def save_checkpoint(self, checkpoint_id: int) -> None:
|
||||
[agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
|
||||
def save_checkpoint(self, checkpoint_prefix: str) -> None:
|
||||
[agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()]
|
||||
|
||||
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||
[agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
|
||||
|
||||
@@ -203,7 +203,7 @@ class NECAgent(ValueOptimizationAgent):
|
||||
self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
|
||||
actions, discounted_rewards)
|
||||
|
||||
def save_checkpoint(self, checkpoint_id):
|
||||
super().save_checkpoint(checkpoint_id)
|
||||
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
|
||||
def save_checkpoint(self, checkpoint_prefix):
|
||||
super().save_checkpoint(checkpoint_prefix)
|
||||
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_prefix) + '.dnd'), 'wb') as f:
|
||||
pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
Reference in New Issue
Block a user