mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Fixes for having NumpySharedRunningStats syncing on multi-node (#139)
1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3. 2. Removing the reference to Redis so that it won't try to pickle that in. 3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run.
This commit is contained in:
@@ -66,19 +66,20 @@ class Filter(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None:
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||
"""
|
||||
Save the filter's internal state to a checkpoint to file, so that it can be later restored.
|
||||
:param checkpoint_dir: the directory in which to save the filter
|
||||
:param checkpoint_id: the checkpoint's ID
|
||||
:param checkpoint_dir: the directory in which to save the filter's state
|
||||
:param checkpoint_prefix: the prefix of the checkpoint file to save
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||
"""
|
||||
Save the filter's internal state to a checkpoint to file, so that it can be later restored.
|
||||
:param checkpoint_dir: the directory in which to save the filter
|
||||
:param checkpoint_dir: the directory from which to restore
|
||||
:param checkpoint_prefix: the checkpoint prefix to look for
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
@@ -221,15 +222,25 @@ class OutputFilter(Filter):
|
||||
"""
|
||||
del self._action_filters[filter_name]
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
|
||||
"""
|
||||
Currently not in use for OutputFilter.
|
||||
:param checkpoint_dir:
|
||||
:param checkpoint_id:
|
||||
:param checkpoint_dir: the directory in which to save the filter's state
|
||||
:param checkpoint_prefix: the prefix of the checkpoint file to save
|
||||
:return:
|
||||
"""
|
||||
pass
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||
"""
|
||||
Currently not in use for OutputFilter.
|
||||
:param checkpoint_dir: the directory from which to restore
|
||||
:param checkpoint_prefix: the checkpoint prefix to look for
|
||||
:return: None
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class NoOutputFilter(OutputFilter):
|
||||
"""
|
||||
@@ -444,40 +455,45 @@ class InputFilter(Filter):
|
||||
"""
|
||||
del self._reward_filters[filter_name]
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
|
||||
"""
|
||||
Save the filter's internal state to a checkpoint to file, so that it can be later restored.
|
||||
:param checkpoint_dir: the directory in which to save the filter
|
||||
:param checkpoint_id: the checkpoint's ID
|
||||
:param checkpoint_dir: the directory in which to save the filter's state
|
||||
:param checkpoint_prefix: the prefix of the checkpoint file to save
|
||||
:return: None
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
|
||||
if self.name is not None:
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, self.name)
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||
for filter_name, filter in self._reward_filters.items():
|
||||
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id)
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
for observation_name, filters_dict in self._observation_filters.items():
|
||||
for filter_name, filter in filters_dict.items():
|
||||
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name,
|
||||
filter_name), checkpoint_id)
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
filter_name])
|
||||
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
|
||||
"""
|
||||
Save the filter's internal state to a checkpoint to file, so that it can be later restored.
|
||||
:param checkpoint_dir: the directory in which to save the filter
|
||||
:param checkpoint_dir: the directory from which to restore
|
||||
:param checkpoint_prefix: the checkpoint prefix to look for
|
||||
:return: None
|
||||
"""
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
|
||||
if self.name is not None:
|
||||
checkpoint_dir = os.path.join(checkpoint_dir, self.name)
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
|
||||
for filter_name, filter in self._reward_filters.items():
|
||||
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name))
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
for observation_name, filters_dict in self._observation_filters.items():
|
||||
for filter_name, filter in filters_dict.items():
|
||||
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters',
|
||||
observation_name, filter_name))
|
||||
checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
|
||||
filter_name])
|
||||
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
|
||||
class NoInputFilter(InputFilter):
|
||||
|
||||
@@ -82,11 +82,8 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
clip_values=(self.clip_min, self.clip_max))
|
||||
return input_observation_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
@@ -75,8 +75,8 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
return input_reward_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
os.makedirs(checkpoint_dir)
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
|
||||
self.running_rewards_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
|
||||
|
||||
Reference in New Issue
Block a user