From a1c56edd98898ed25db9c6ed1e05b50a0a85e097 Mon Sep 17 00:00:00 2001 From: Gal Leibovich Date: Fri, 23 Nov 2018 16:11:47 +0200 Subject: [PATCH] Fixes for having NumpySharedRunningStats syncing on multi-node (#139) 1. Having the standard checkpoint prefix in order for the data store to grab it, and sync it to S3. 2. Removing the reference to Redis so that it won't try to pickle that in. 3. Enable restoring a checkpoint into a single-worker run, which was saved by a single-node-multiple-worker run. --- rl_coach/agents/agent.py | 25 ++++---- rl_coach/agents/agent_interface.py | 4 +- rl_coach/agents/composite_agent.py | 4 +- rl_coach/agents/nec_agent.py | 6 +- .../tensorflow_components/shared_variables.py | 45 +++++++++----- rl_coach/filters/filter.py | 62 ++++++++++++------- .../observation_normalization_filter.py | 11 ++-- .../reward/reward_normalization_filter.py | 8 +-- rl_coach/graph_managers/graph_manager.py | 11 ++-- rl_coach/level_manager.py | 5 +- .../differentiable_neural_dictionary.py | 19 +++--- rl_coach/utilities/shared_running_stats.py | 53 ++++++++++------ 12 files changed, 154 insertions(+), 99 deletions(-) diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py index eb5bf62..21be6be 100644 --- a/rl_coach/agents/agent.py +++ b/rl_coach/agents/agent.py @@ -925,30 +925,31 @@ class Agent(AgentInterface): self.input_filter.observation_filters['attention'].crop_high = action[1] self.output_filter.action_filters['masking'].set_masking(action[0], action[1]) - def save_checkpoint(self, checkpoint_id: int) -> None: + def save_checkpoint(self, checkpoint_prefix: str) -> None: """ Allows agents to store additional information when saving checkpoints. - :param checkpoint_id: the id of the checkpoint + :param checkpoint_prefix: The prefix of the checkpoint file to save :return: None """ - checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir, - *(self.full_name_id.split('/'))) # adds both level name and agent name - self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id) - self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id) - self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id) + checkpoint_dir = self.ap.task_parameters.checkpoint_save_dir + + checkpoint_prefix = '.'.join([checkpoint_prefix] + self.full_name_id.split('/')) # adds both level name and agent name + + self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) + self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) + self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) def restore_checkpoint(self, checkpoint_dir: str) -> None: """ Allows agents to store additional information when saving checkpoints. - :param checkpoint_id: the id of the checkpoint + :param checkpoint_dir: The checkpoint dir to restore from :return: None """ - checkpoint_dir = os.path.join(checkpoint_dir, - *(self.full_name_id.split('/'))) # adds both level name and agent name - self.input_filter.restore_state_from_checkpoint(checkpoint_dir) - self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir) + checkpoint_prefix = '.'.join(self.full_name_id.split('/')) # adds both level name and agent name + self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) + self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) # no output filters currently have an internal state to restore # self.output_filter.restore_state_from_checkpoint(checkpoint_dir) diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py index 0a7aaab..f3b7903 100644 --- a/rl_coach/agents/agent_interface.py +++ b/rl_coach/agents/agent_interface.py @@ -98,10 +98,10 @@ class AgentInterface(object): """ raise NotImplementedError("") - def save_checkpoint(self, checkpoint_id: int) -> None: + def save_checkpoint(self, checkpoint_prefix: str) -> None: """ Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc. - :param checkpoint_id: the checkpoint id to use for saving + :param checkpoint_prefix: The prefix of the checkpoint file to save :return: None """ raise NotImplementedError("") diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py index 354c36a..33b437b 100644 --- a/rl_coach/agents/composite_agent.py +++ b/rl_coach/agents/composite_agent.py @@ -389,8 +389,8 @@ class CompositeAgent(AgentInterface): # probably better to only return the agents' goal_reached decisions. return episode_ended - def save_checkpoint(self, checkpoint_id: int) -> None: - [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()] + def save_checkpoint(self, checkpoint_prefix: str) -> None: + [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()] def restore_checkpoint(self, checkpoint_dir: str) -> None: [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()] diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py index 4baf08a..9eabb78 100644 --- a/rl_coach/agents/nec_agent.py +++ b/rl_coach/agents/nec_agent.py @@ -203,7 +203,7 @@ class NECAgent(ValueOptimizationAgent): self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings, actions, discounted_rewards) - def save_checkpoint(self, checkpoint_id): - super().save_checkpoint(checkpoint_id) - with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f: + def save_checkpoint(self, checkpoint_prefix): + super().save_checkpoint(checkpoint_prefix) + with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_prefix) + '.dnd'), 'wb') as f: pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL) diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py index 4a38e3e..fe805af 100644 --- a/rl_coach/architectures/tensorflow_components/shared_variables.py +++ b/rl_coach/architectures/tensorflow_components/shared_variables.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import pickle import numpy as np import tensorflow as tf @@ -46,10 +48,10 @@ class TFSharedRunningStats(SharedRunningStats): initializer=tf.constant_initializer(0.0), name="running_sum", trainable=False, shape=shape, validate_shape=False, collections=[tf.GraphKeys.GLOBAL_VARIABLES]) - self._sum_squared = tf.get_variable( + self._sum_squares = tf.get_variable( dtype=tf.float64, initializer=tf.constant_initializer(self.epsilon), - name="running_sum_squared", trainable=False, shape=shape, validate_shape=False, + name="running_sum_squares", trainable=False, shape=shape, validate_shape=False, collections=[tf.GraphKeys.GLOBAL_VARIABLES]) self._count = tf.get_variable( dtype=tf.float64, @@ -59,17 +61,17 @@ class TFSharedRunningStats(SharedRunningStats): self._shape = None self._mean = tf.div(self._sum, self._count, name="mean") - self._std = tf.sqrt(tf.maximum((self._sum_squared - self._count*tf.square(self._mean)) + self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean)) / tf.maximum(self._count-1, 1), self.epsilon), name="stdev") self.tf_mean = tf.cast(self._mean, 'float32') self.tf_std = tf.cast(self._std, 'float32') self.new_sum = tf.placeholder(dtype=tf.float64, name='sum') - self.new_sum_squared = tf.placeholder(dtype=tf.float64, name='var') + self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var') self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count') self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True) - self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared, use_locking=True) + self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True) self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True) self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs') @@ -84,10 +86,10 @@ class TFSharedRunningStats(SharedRunningStats): def push_val(self, x): x = x.astype('float64') - self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count], + self.sess.run([self._inc_sum, self._inc_sum_squares, self._inc_count], feed_dict={ self.new_sum: x.sum(axis=0).ravel(), - self.new_sum_squared: np.square(x).sum(axis=0).ravel(), + self.new_sum_squares: np.square(x).sum(axis=0).ravel(), self.newcount: np.array(len(x), dtype='float64') }) if self._shape is None: @@ -117,11 +119,11 @@ class TFSharedRunningStats(SharedRunningStats): def shape(self, val): self._shape = val self.new_sum.set_shape(val) - self.new_sum_squared.set_shape(val) + self.new_sum_squares.set_shape(val) self.tf_mean.set_shape(val) self.tf_std.set_shape(val) self._sum.set_shape(val) - self._sum_squared.set_shape(val) + self._sum_squares.set_shape(val) def normalize(self, batch): if self.clip_values is not None: @@ -129,10 +131,25 @@ class TFSharedRunningStats(SharedRunningStats): else: return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch}) - def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int): - # the stats are part of the TF graph - no need to explicitly save anything - pass + def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + # Since the internal state is maintained as part of the TF graph, no need to do anything special for + # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run. + # Nevertheless, if we'll want to restore a checkpoint back to either a * single-worker, or a + # multi-node-multi-worker * run, we have to save the internal state, so that it can be restored to the + # NumpySharedRunningStats implementation. - def restore_state_from_checkpoint(self, checkpoint_dir: str): - # the stats are part of the TF graph - no need to explicitly restore anything + dict_to_save = {'_mean': self.mean, + '_std': self.std, + '_count': self.n, + '_sum': self.sess.run(self._sum), + '_sum_squares': self.sess.run(self._sum_squares)} + + with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f: + pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL) + + def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + # Since the internal state is maintained as part of the TF graph, no need to do anything special for + # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run. + # Restoring from either a * single-worker, or a multi-node-multi-worker * run, to a single-node-multi-thread run + # is not supported. pass diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py index 28a195f..6ad2d55 100644 --- a/rl_coach/filters/filter.py +++ b/rl_coach/filters/filter.py @@ -66,19 +66,20 @@ class Filter(object): """ pass - def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None: + def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None: """ Save the filter's internal state to a checkpoint to file, so that it can be later restored. - :param checkpoint_dir: the directory in which to save the filter - :param checkpoint_id: the checkpoint's ID + :param checkpoint_dir: the directory in which to save the filter's state + :param checkpoint_prefix: the prefix of the checkpoint file to save :return: None """ pass - def restore_state_from_checkpoint(self, checkpoint_dir)->None: + def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None: """ Save the filter's internal state to a checkpoint to file, so that it can be later restored. - :param checkpoint_dir: the directory in which to save the filter + :param checkpoint_dir: the directory from which to restore + :param checkpoint_prefix: the checkpoint prefix to look for :return: None """ pass @@ -221,15 +222,25 @@ class OutputFilter(Filter): """ del self._action_filters[filter_name] - def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): + def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix): """ Currently not in use for OutputFilter. - :param checkpoint_dir: - :param checkpoint_id: + :param checkpoint_dir: the directory in which to save the filter's state + :param checkpoint_prefix: the prefix of the checkpoint file to save :return: """ pass + def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None: + """ + Currently not in use for OutputFilter. + :param checkpoint_dir: the directory from which to restore + :param checkpoint_prefix: the checkpoint prefix to look for + :return: None + """ + pass + + class NoOutputFilter(OutputFilter): """ @@ -444,40 +455,45 @@ class InputFilter(Filter): """ del self._reward_filters[filter_name] - def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): + def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix): """ Save the filter's internal state to a checkpoint to file, so that it can be later restored. - :param checkpoint_dir: the directory in which to save the filter - :param checkpoint_id: the checkpoint's ID + :param checkpoint_dir: the directory in which to save the filter's state + :param checkpoint_prefix: the prefix of the checkpoint file to save :return: None """ - checkpoint_dir = os.path.join(checkpoint_dir, 'filters') + checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters']) if self.name is not None: - checkpoint_dir = os.path.join(checkpoint_dir, self.name) + checkpoint_prefix = '.'.join([checkpoint_prefix, self.name]) for filter_name, filter in self._reward_filters.items(): - filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id) + checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) + filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) for observation_name, filters_dict in self._observation_filters.items(): for filter_name, filter in filters_dict.items(): - filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name, - filter_name), checkpoint_id) + checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, + filter_name]) + filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) - def restore_state_from_checkpoint(self, checkpoint_dir)->None: + def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None: """ Save the filter's internal state to a checkpoint to file, so that it can be later restored. - :param checkpoint_dir: the directory in which to save the filter + :param checkpoint_dir: the directory from which to restore + :param checkpoint_prefix: the checkpoint prefix to look for :return: None """ - checkpoint_dir = os.path.join(checkpoint_dir, 'filters') + checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters']) if self.name is not None: - checkpoint_dir = os.path.join(checkpoint_dir, self.name) + checkpoint_prefix = '.'.join([checkpoint_prefix, self.name]) for filter_name, filter in self._reward_filters.items(): - filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name)) + checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name]) + filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) for observation_name, filters_dict in self._observation_filters.items(): for filter_name, filter in filters_dict.items(): - filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', - observation_name, filter_name)) + checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name, + filter_name]) + filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) class NoInputFilter(InputFilter): diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py index 6edd6c9..db9e104 100644 --- a/rl_coach/filters/observation/observation_normalization_filter.py +++ b/rl_coach/filters/observation/observation_normalization_filter.py @@ -82,11 +82,8 @@ class ObservationNormalizationFilter(ObservationFilter): clip_values=(self.clip_min, self.clip_max)) return input_observation_space - def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int): - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) + def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) - self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id) - - def restore_state_from_checkpoint(self, checkpoint_dir: str): - self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir) + def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py index 1541966..b708c93 100644 --- a/rl_coach/filters/reward/reward_normalization_filter.py +++ b/rl_coach/filters/reward/reward_normalization_filter.py @@ -75,8 +75,8 @@ class RewardNormalizationFilter(RewardFilter): def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace: return input_reward_space - def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int): - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) + def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix) - self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id) \ No newline at end of file + def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + self.running_rewards_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix) diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py index 180681d..518c388 100644 --- a/rl_coach/graph_managers/graph_manager.py +++ b/rl_coach/graph_managers/graph_manager.py @@ -581,10 +581,13 @@ class GraphManager(object): def save_checkpoint(self): if self.task_parameters.checkpoint_save_dir is None: self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint') - checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, - "{}_Step-{}.ckpt".format( + + filename = "{}_Step-{}.ckpt".format( self.checkpoint_id, - self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])) + self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]) + + checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir, + filename) if not os.path.exists(os.path.dirname(checkpoint_path)): os.mkdir(os.path.dirname(checkpoint_path)) # Create directory structure if not isinstance(self.task_parameters, DistributedTaskParameters): @@ -593,7 +596,7 @@ class GraphManager(object): saved_checkpoint_path = checkpoint_path # this is required in order for agents to save additional information like a DND for example - [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers] + [manager.save_checkpoint(filename) for manager in self.level_managers] # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used if self.task_parameters.export_onnx_graph: diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py index 1a344f1..91a7a19 100644 --- a/rl_coach/level_manager.py +++ b/rl_coach/level_manager.py @@ -248,12 +248,13 @@ class LevelManager(EnvironmentInterface): return env_response_for_upper_level - def save_checkpoint(self, checkpoint_id: int) -> None: + def save_checkpoint(self, checkpoint_prefix: str) -> None: """ Save checkpoints of the networks of all agents + :param: checkpoint_prefix: The prefix of the checkpoint file to save :return: None """ - [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()] + [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()] def restore_checkpoint(self, checkpoint_dir: str) -> None: """ diff --git a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py index 456bc96..3368ee8 100644 --- a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py +++ b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py @@ -269,14 +269,19 @@ class QDND(object): def load_dnd(model_dir): - max_id = 0 + latest_checkpoint_id = -1 + latest_checkpoint = '' + # get all checkpoint files + for fname in os.listdir(model_dir): + path = os.path.join(model_dir, fname) + if os.path.isdir(path) or fname.split('.')[-1] != 'srs': + continue + checkpoint_id = int(fname.split('_')[0]) + if checkpoint_id > latest_checkpoint_id: + latest_checkpoint = fname + latest_checkpoint_id = checkpoint_id - for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]: - if int(f.split('.')[0]) > max_id: - max_id = int(f.split('.')[0]) - - model_path = str(max_id) + '.dnd' - with open(os.path.join(model_dir, model_path), 'rb') as f: + with open(os.path.join(model_dir, str(latest_checkpoint)), 'rb') as f: DND = pickle.load(f) for a in range(DND.num_actions): diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py index c76f232..7f1176f 100644 --- a/rl_coach/utilities/shared_running_stats.py +++ b/rl_coach/utilities/shared_running_stats.py @@ -21,7 +21,6 @@ import redis import numpy as np - class SharedRunningStatsSubscribe(threading.Thread): def __init__(self, shared_running_stats): super().__init__() @@ -103,13 +102,28 @@ class SharedRunningStats(ABC): pass @abstractmethod - def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int): + def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int): pass @abstractmethod - def restore_state_from_checkpoint(self, checkpoint_dir: str): + def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): pass + def get_latest_checkpoint(self, checkpoint_dir: str) -> str: + latest_checkpoint_id = -1 + latest_checkpoint = '' + # get all checkpoint files + for fname in os.listdir(checkpoint_dir): + path = os.path.join(checkpoint_dir, fname) + if os.path.isdir(path) or fname.split('.')[-1] != 'srs': + continue + checkpoint_id = int(fname.split('_')[0]) + if checkpoint_id > latest_checkpoint_id: + latest_checkpoint = fname + latest_checkpoint_id = checkpoint_id + + return latest_checkpoint + class NumpySharedRunningStats(SharedRunningStats): def __init__(self, name, epsilon=1e-2, pubsub_params=None): @@ -164,21 +178,22 @@ class NumpySharedRunningStats(SharedRunningStats): # no session for the numpy implementation pass - def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int): - with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f: - pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL) + def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int): + dict_to_save = {'_mean': self._mean, + '_std': self._std, + '_count': self._count, + '_sum': self._sum, + '_sum_squares': self._sum_squares} - def restore_state_from_checkpoint(self, checkpoint_dir: str): - latest_checkpoint = -1 - # get all checkpoint files - for fname in os.listdir(checkpoint_dir): - path = os.path.join(checkpoint_dir, fname) - if os.path.isdir(path): - continue - checkpoint_id = int(fname.split('.')[0]) - if checkpoint_id > latest_checkpoint: - latest_checkpoint = checkpoint_id + with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f: + pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL) - with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f: - temp_running_observation_stats = pickle.load(f) - self.__dict__.update(temp_running_observation_stats) + def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str): + latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir) + + if latest_checkpoint_filename == '': + raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ") + + with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f: + saved_dict = pickle.load(f) + self.__dict__.update(saved_dict)