1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Fixes for having NumpySharedRunningStats syncing on multi-node (#139)

1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3.
2. Removing the reference to Redis so that it won't try to pickle that in.
3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run.
This commit is contained in:
Gal Leibovich
2018-11-23 16:11:47 +02:00
committed by GitHub
parent 87a7848b0a
commit a1c56edd98
12 changed files with 154 additions and 99 deletions

View File

@@ -925,30 +925,31 @@ class Agent(AgentInterface):
self.input_filter.observation_filters['attention'].crop_high = action[1]
self.output_filter.action_filters['masking'].set_masking(action[0], action[1])
def save_checkpoint(self, checkpoint_id: int) -> None:
def save_checkpoint(self, checkpoint_prefix: str) -> None:
"""
Allows agents to store additional information when saving checkpoints.
:param checkpoint_id: the id of the checkpoint
:param checkpoint_prefix: The prefix of the checkpoint file to save
:return: None
"""
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
*(self.full_name_id.split('/'))) # adds both level name and agent name
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
checkpoint_dir = self.ap.task_parameters.checkpoint_save_dir
checkpoint_prefix = '.'.join([checkpoint_prefix] + self.full_name_id.split('/')) # adds both level name and agent name
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
def restore_checkpoint(self, checkpoint_dir: str) -> None:
"""
Allows agents to store additional information when saving checkpoints.
:param checkpoint_id: the id of the checkpoint
:param checkpoint_dir: The checkpoint dir to restore from
:return: None
"""
checkpoint_dir = os.path.join(checkpoint_dir,
*(self.full_name_id.split('/'))) # adds both level name and agent name
self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
checkpoint_prefix = '.'.join(self.full_name_id.split('/')) # adds both level name and agent name
self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
# no output filters currently have an internal state to restore
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)