diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index eb5bf62..21be6be 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -925,30 +925,31 @@ class Agent(AgentInterface):
         self.input_filter.observation_filters['attention'].crop_high = action[1]
         self.output_filter.action_filters['masking'].set_masking(action[0], action[1])
 
-    def save_checkpoint(self, checkpoint_id: int) -> None:
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
         """
         Allows agents to store additional information when saving checkpoints.
-        :param checkpoint_id: the id of the checkpoint
+        :param checkpoint_prefix: The prefix of the checkpoint file to save
         :return: None
         """
-        checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
-                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
-        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        checkpoint_dir = self.ap.task_parameters.checkpoint_save_dir
+
+        checkpoint_prefix = '.'.join([checkpoint_prefix] + self.full_name_id.split('/'))  # adds both level name and agent name
+
+        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
 
     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         """
-        Allows agents to store additional information when saving checkpoints.
-        :param checkpoint_id: the id of the checkpoint
+        Allows agents to restore additional information when restoring from a checkpoint.
+        :param checkpoint_dir: The checkpoint dir to restore from
         :return: None
         """
-        checkpoint_dir = os.path.join(checkpoint_dir,
-                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
-        self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
-        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
+        checkpoint_prefix = '.'.join(self.full_name_id.split('/'))  # adds both level name and agent name
+        self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
 
         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
diff --git a/rl_coach/agents/agent_interface.py b/rl_coach/agents/agent_interface.py
index 0a7aaab..f3b7903 100644
--- a/rl_coach/agents/agent_interface.py
+++ b/rl_coach/agents/agent_interface.py
@@ -98,10 +98,10 @@ class AgentInterface(object):
         """
         raise NotImplementedError("")
 
-    def save_checkpoint(self, checkpoint_id: int) -> None:
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
         """
         Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc.
-        :param checkpoint_id: the checkpoint id to use for saving
+        :param checkpoint_prefix: The prefix of the checkpoint file to save
         :return: None
         """
         raise NotImplementedError("")
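
The agent-level change above replaces the old scheme of nested per-agent checkpoint directories with a flat, dot-joined prefix. A minimal sketch of the resulting naming, using hypothetical level/agent names (the filename itself comes from GraphManager.save_checkpoint, further below):

    # the graph manager passes the checkpoint filename down to each agent, which
    # appends its full name (level + agent, separated by '/') with '.' separators
    # instead of creating nested directories
    filename = "10_Step-2000.ckpt"        # hypothetical graph-level checkpoint name
    full_name_id = "level_0/main_agent"   # hypothetical level/agent name

    checkpoint_prefix = '.'.join([filename] + full_name_id.split('/'))
    assert checkpoint_prefix == "10_Step-2000.ckpt.level_0.main_agent"
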
diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py
index 354c36a..33b437b 100644
--- a/rl_coach/agents/composite_agent.py
+++ b/rl_coach/agents/composite_agent.py
@@ -389,8 +389,8 @@ class CompositeAgent(AgentInterface):
         # probably better to only return the agents' goal_reached decisions.
         return episode_ended
 
-    def save_checkpoint(self, checkpoint_id: int) -> None:
-        [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
+        [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()]
 
     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py
index 4baf08a..9eabb78 100644
--- a/rl_coach/agents/nec_agent.py
+++ b/rl_coach/agents/nec_agent.py
@@ -203,7 +203,7 @@ class NECAgent(ValueOptimizationAgent):
         self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
                                                                      actions, discounted_rewards)
 
-    def save_checkpoint(self, checkpoint_id):
-        super().save_checkpoint(checkpoint_id)
-        with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
+    def save_checkpoint(self, checkpoint_prefix):
+        super().save_checkpoint(checkpoint_prefix)
+        with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_prefix) + '.dnd'), 'wb') as f:
             pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
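
A hedged sketch (hypothetical paths and prefix) of the DND round trip implied by NECAgent.save_checkpoint and load_dnd: the DND is pickled next to the network checkpoint under '<prefix>.dnd', and load_dnd later picks the file with the highest leading checkpoint id.

    import os
    import pickle

    def save_dnd(dnd, checkpoint_save_dir, checkpoint_prefix):
        # pickle the DND alongside the regular checkpoint files, same prefix, '.dnd' extension
        with open(os.path.join(checkpoint_save_dir, checkpoint_prefix + '.dnd'), 'wb') as f:
            pickle.dump(dnd, f, pickle.HIGHEST_PROTOCOL)

    # e.g. with checkpoint_prefix = '10_Step-2000.ckpt.level_0.main_agent', this
    # writes '10_Step-2000.ckpt.level_0.main_agent.dnd', whose leading '10' is
    # what load_dnd compares when several checkpoints share the directory.
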
diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py
index 4a38e3e..fe805af 100644
--- a/rl_coach/architectures/tensorflow_components/shared_variables.py
+++ b/rl_coach/architectures/tensorflow_components/shared_variables.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import pickle
 
 import numpy as np
 import tensorflow as tf
@@ -46,10 +48,10 @@ class TFSharedRunningStats(SharedRunningStats):
             initializer=tf.constant_initializer(0.0),
             name="running_sum", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
-        self._sum_squared = tf.get_variable(
+        self._sum_squares = tf.get_variable(
             dtype=tf.float64,
             initializer=tf.constant_initializer(self.epsilon),
-            name="running_sum_squared", trainable=False, shape=shape, validate_shape=False,
+            name="running_sum_squares", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
         self._count = tf.get_variable(
             dtype=tf.float64,
@@ -59,17 +61,17 @@ class TFSharedRunningStats(SharedRunningStats):
         self._shape = None
         self._mean = tf.div(self._sum, self._count, name="mean")
-        self._std = tf.sqrt(tf.maximum((self._sum_squared - self._count*tf.square(self._mean))
+        self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean))
                                        / tf.maximum(self._count-1, 1), self.epsilon), name="stdev")
 
         self.tf_mean = tf.cast(self._mean, 'float32')
         self.tf_std = tf.cast(self._std, 'float32')
 
         self.new_sum = tf.placeholder(dtype=tf.float64, name='sum')
-        self.new_sum_squared = tf.placeholder(dtype=tf.float64, name='var')
+        self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var')
         self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')
 
         self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True)
-        self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared, use_locking=True)
+        self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True)
         self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True)
 
         self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs')
@@ -84,10 +86,10 @@ class TFSharedRunningStats(SharedRunningStats):
 
     def push_val(self, x):
         x = x.astype('float64')
-        self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
+        self.sess.run([self._inc_sum, self._inc_sum_squares, self._inc_count],
                       feed_dict={
                           self.new_sum: x.sum(axis=0).ravel(),
-                          self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
+                          self.new_sum_squares: np.square(x).sum(axis=0).ravel(),
                           self.newcount: np.array(len(x), dtype='float64')
                       })
         if self._shape is None:
@@ -117,11 +119,11 @@ class TFSharedRunningStats(SharedRunningStats):
     def shape(self, val):
         self._shape = val
         self.new_sum.set_shape(val)
-        self.new_sum_squared.set_shape(val)
+        self.new_sum_squares.set_shape(val)
         self.tf_mean.set_shape(val)
         self.tf_std.set_shape(val)
         self._sum.set_shape(val)
-        self._sum_squared.set_shape(val)
+        self._sum_squares.set_shape(val)
 
     def normalize(self, batch):
         if self.clip_values is not None:
@@ -129,10 +131,25 @@ class TFSharedRunningStats(SharedRunningStats):
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        # the stats are part of the TF graph - no need to explicitly save anything
-        pass
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, there is no need to do anything special
+        # for save/restore when going from a single-node-multi-thread run back to a single-node-multi-worker run.
+        # Nevertheless, if we want to restore a checkpoint back to either a single-worker or a
+        # multi-node-multi-worker run, we have to save the internal state, so that it can be restored into the
+        # NumpySharedRunningStats implementation.
+        dict_to_save = {'_mean': self.mean,
+                        '_std': self.std,
+                        '_count': self.n,
+                        '_sum': self.sess.run(self._sum),
+                        '_sum_squares': self.sess.run(self._sum_squares)}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
 
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        # the stats are part of the TF graph - no need to explicitly restore anything
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, it is restored together with the rest
+        # of the graph when going back to a single-node-multi-worker run. Restoring from either a single-worker
+        # or a multi-node-multi-worker run to a single-node-multi-thread run is not supported.
         pass
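
The renamed variables make the algebra easier to follow: the graph keeps only a running sum, a running sum of squares, and a count, and derives mean = sum / count and the standard deviation via the identity sum((x - mean)^2) = sum(x^2) - n * mean^2. A small NumPy check of that identity (not part of the patch, and omitting the epsilon floor the TF version applies):

    import numpy as np

    x = np.random.randn(1000, 4)

    running_sum = x.sum(axis=0)
    running_sum_squares = np.square(x).sum(axis=0)
    count = len(x)

    mean = running_sum / count
    std = np.sqrt((running_sum_squares - count * np.square(mean)) / max(count - 1, 1))

    assert np.allclose(mean, x.mean(axis=0))
    assert np.allclose(std, x.std(axis=0, ddof=1))
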
diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py
index 28a195f..6ad2d55 100644
--- a/rl_coach/filters/filter.py
+++ b/rl_coach/filters/filter.py
@@ -66,19 +66,20 @@ class Filter(object):
         """
         pass
 
-    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None:
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
         """
-        Save the filter's internal state to a checkpoint to file, so that it can be later restored.
-        :param checkpoint_dir: the directory in which to save the filter
-        :param checkpoint_id: the checkpoint's ID
+        Save the filter's internal state to a checkpoint file, so that it can later be restored.
+        :param checkpoint_dir: the directory in which to save the filter's state
+        :param checkpoint_prefix: the prefix of the checkpoint file to save
         :return: None
         """
         pass
 
-    def restore_state_from_checkpoint(self, checkpoint_dir)->None:
+    def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
         """
-        Save the filter's internal state to a checkpoint to file, so that it can be later restored.
-        :param checkpoint_dir: the directory in which to save the filter
+        Restore the filter's internal state from a checkpoint file.
+        :param checkpoint_dir: the directory from which to restore
+        :param checkpoint_prefix: the checkpoint prefix to look for
         :return: None
         """
         pass
@@ -221,15 +222,25 @@ class OutputFilter(Filter):
         """
         del self._action_filters[filter_name]
 
-    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
         """
         Currently not in use for OutputFilter.
-        :param checkpoint_dir:
-        :param checkpoint_id:
+        :param checkpoint_dir: the directory in which to save the filter's state
+        :param checkpoint_prefix: the prefix of the checkpoint file to save
         :return:
         """
         pass
 
+    def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
+        """
+        Currently not in use for OutputFilter.
+        :param checkpoint_dir: the directory from which to restore
+        :param checkpoint_prefix: the checkpoint prefix to look for
+        :return: None
+        """
+        pass
+
 
 class NoOutputFilter(OutputFilter):
     """
@@ -444,40 +455,45 @@ class InputFilter(Filter):
         """
         del self._reward_filters[filter_name]
 
-    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
         """
-        Save the filter's internal state to a checkpoint to file, so that it can be later restored.
-        :param checkpoint_dir: the directory in which to save the filter
-        :param checkpoint_id: the checkpoint's ID
+        Save the filter's internal state to a checkpoint file, so that it can later be restored.
+        :param checkpoint_dir: the directory in which to save the filter's state
+        :param checkpoint_prefix: the prefix of the checkpoint file to save
        :return: None
         """
-        checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
+        checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
         if self.name is not None:
-            checkpoint_dir = os.path.join(checkpoint_dir, self.name)
+            checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
 
         for filter_name, filter in self._reward_filters.items():
-            filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id)
+            # build a per-filter prefix, so that the prefix does not accumulate across loop iterations
+            filter_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
+            filter.save_state_to_checkpoint(checkpoint_dir, filter_prefix)
 
         for observation_name, filters_dict in self._observation_filters.items():
             for filter_name, filter in filters_dict.items():
-                filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name,
-                                                filter_name), checkpoint_id)
+                filter_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
+                                          filter_name])
+                filter.save_state_to_checkpoint(checkpoint_dir, filter_prefix)
 
-    def restore_state_from_checkpoint(self, checkpoint_dir)->None:
+    def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
         """
-        Save the filter's internal state to a checkpoint to file, so that it can be later restored.
-        :param checkpoint_dir: the directory in which to save the filter
+        Restore the filter's internal state from a checkpoint file.
+        :param checkpoint_dir: the directory from which to restore
+        :param checkpoint_prefix: the checkpoint prefix to look for
         :return: None
         """
-        checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
+        checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
         if self.name is not None:
-            checkpoint_dir = os.path.join(checkpoint_dir, self.name)
+            checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
 
         for filter_name, filter in self._reward_filters.items():
-            filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name))
+            filter_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
+            filter.restore_state_from_checkpoint(checkpoint_dir, filter_prefix)
 
         for observation_name, filters_dict in self._observation_filters.items():
             for filter_name, filter in filters_dict.items():
-                filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters',
-                                                     observation_name, filter_name))
+                filter_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
+                                          filter_name])
+                filter.restore_state_from_checkpoint(checkpoint_dir, filter_prefix)
 
 
 class NoInputFilter(InputFilter):
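
The filter hierarchy that used to be encoded as nested directories is now folded into the prefix itself. An illustration with hypothetical agent and filter names (the old layout is reconstructed from the removed lines above):

    # old (one directory per filter, file named after the checkpoint id):
    #   <ckpt_dir>/<agent>/filters/observation_filters/observation/normalize/10.srs
    # new (single flat directory, hierarchy encoded in the dot-joined prefix):
    #   <ckpt_dir>/10_Step-2000.ckpt.<agent>.filters.observation_filters.observation.normalize.srs
    prefix = '.'.join(['10_Step-2000.ckpt.main_agent', 'filters',
                       'observation_filters', 'observation', 'normalize'])
    assert prefix == '10_Step-2000.ckpt.main_agent.filters.observation_filters.observation.normalize'
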
diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py
index 6edd6c9..db9e104 100644
--- a/rl_coach/filters/observation/observation_normalization_filter.py
+++ b/rl_coach/filters/observation/observation_normalization_filter.py
@@ -82,11 +82,8 @@ class ObservationNormalizationFilter(ObservationFilter):
                                                  clip_values=(self.clip_min, self.clip_max))
         return input_observation_space
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        if not os.path.exists(checkpoint_dir):
-            os.makedirs(checkpoint_dir)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
 
-        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py
index 1541966..b708c93 100644
--- a/rl_coach/filters/reward/reward_normalization_filter.py
+++ b/rl_coach/filters/reward/reward_normalization_filter.py
@@ -75,8 +75,8 @@ class RewardNormalizationFilter(RewardFilter):
     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         return input_reward_space
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        if not os.path.exists(checkpoint_dir):
-            os.makedirs(checkpoint_dir)
-
-        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
\ No newline at end of file
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_rewards_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
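
Both normalization filters lose their os.makedirs guard: under the flat scheme they no longer create per-filter subdirectories, and everything lands in the checkpoint directory that GraphManager.save_checkpoint already creates. A hypothetical layout of such a directory:

    # illustrative file names only; the exact agent/filter names depend on the preset
    expected_layout = [
        '10_Step-2000.ckpt',  # network checkpoint prefix written by the TF saver
        '10_Step-2000.ckpt.agent.filters.reward_filters.normalize_reward.srs',
        '10_Step-2000.ckpt.agent.filters.observation_filters.observation.normalize_observation.srs',
    ]
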
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index 180681d..518c388 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -581,10 +581,13 @@ class GraphManager(object):
     def save_checkpoint(self):
         if self.task_parameters.checkpoint_save_dir is None:
             self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
-        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
-                                       "{}_Step-{}.ckpt".format(
-                                           self.checkpoint_id,
-                                           self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+
+        filename = "{}_Step-{}.ckpt".format(
+            self.checkpoint_id,
+            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
+
+        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
+                                       filename)
         if not os.path.exists(os.path.dirname(checkpoint_path)):
             os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure
         if not isinstance(self.task_parameters, DistributedTaskParameters):
@@ -593,7 +596,7 @@ class GraphManager(object):
             saved_checkpoint_path = checkpoint_path
 
         # this is required in order for agents to save additional information like a DND for example
-        [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
+        [manager.save_checkpoint(filename) for manager in self.level_managers]
 
         # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
         if self.task_parameters.export_onnx_graph:
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index 1a344f1..91a7a19 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -248,12 +248,13 @@ class LevelManager(EnvironmentInterface):
 
         return env_response_for_upper_level
 
-    def save_checkpoint(self, checkpoint_id: int) -> None:
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
         """
         Save checkpoints of the networks of all agents
+        :param checkpoint_prefix: The prefix of the checkpoint file to save
         :return: None
         """
-        [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
+        [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()]
 
     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         """
diff --git a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
index 456bc96..3368ee8 100644
--- a/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
+++ b/rl_coach/memories/non_episodic/differentiable_neural_dictionary.py
@@ -269,14 +269,19 @@ class QDND(object):
 
 
 def load_dnd(model_dir):
-    max_id = 0
+    latest_checkpoint_id = -1
+    latest_checkpoint = ''
+    # get all '.dnd' checkpoint files
+    for fname in os.listdir(model_dir):
+        path = os.path.join(model_dir, fname)
+        if os.path.isdir(path) or fname.split('.')[-1] != 'dnd':
+            continue
+        checkpoint_id = int(fname.split('_')[0])
+        if checkpoint_id > latest_checkpoint_id:
+            latest_checkpoint = fname
+            latest_checkpoint_id = checkpoint_id
 
-    for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
-        if int(f.split('.')[0]) > max_id:
-            max_id = int(f.split('.')[0])
-
-    model_path = str(max_id) + '.dnd'
-    with open(os.path.join(model_dir, model_path), 'rb') as f:
+    with open(os.path.join(model_dir, str(latest_checkpoint)), 'rb') as f:
         DND = pickle.load(f)
 
     for a in range(DND.num_actions):
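
Both load_dnd and the get_latest_checkpoint helper below recover the checkpoint id by parsing the integer before the first underscore of the "{checkpoint_id}_Step-{steps}.ckpt..." filenames that GraphManager now produces. A sketch of why the id is parsed rather than the names sorted, with hypothetical filenames:

    # '10_...' sorts before '9_...' lexicographically, so sorting names is not
    # enough; the code parses the leading integer instead. The same selection
    # can be expressed as a one-liner:
    fnames = ['9_Step-1800.ckpt.agent.dnd', '10_Step-2000.ckpt.agent.dnd']

    latest = max(fnames, key=lambda fname: int(fname.split('_')[0]))
    assert latest == '10_Step-2000.ckpt.agent.dnd'
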
diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py
index c76f232..7f1176f 100644
--- a/rl_coach/utilities/shared_running_stats.py
+++ b/rl_coach/utilities/shared_running_stats.py
@@ -21,7 +21,6 @@
 import redis
 import numpy as np
 
-
 class SharedRunningStatsSubscribe(threading.Thread):
     def __init__(self, shared_running_stats):
         super().__init__()
@@ -103,13 +102,28 @@ class SharedRunningStats(ABC):
         pass
 
     @abstractmethod
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass
 
     @abstractmethod
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass
 
+    def get_latest_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str) -> str:
+        latest_checkpoint_id = -1
+        latest_checkpoint = ''
+        # scan the '.srs' checkpoint files matching this prefix, and keep the one with the highest checkpoint id
+        for fname in os.listdir(checkpoint_dir):
+            path = os.path.join(checkpoint_dir, fname)
+            if os.path.isdir(path) or fname.split('.')[-1] != 'srs' or checkpoint_prefix not in fname:
+                continue
+            checkpoint_id = int(fname.split('_')[0])
+            if checkpoint_id > latest_checkpoint_id:
+                latest_checkpoint = fname
+                latest_checkpoint_id = checkpoint_id
+
+        return latest_checkpoint
+
 
 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -164,21 +178,22 @@ class NumpySharedRunningStats(SharedRunningStats):
         # no session for the numpy implementation
         pass
 
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
-            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # save only the statistics themselves, rather than the full __dict__, so that the file can be
+        # loaded by any SharedRunningStats implementation
+        dict_to_save = {'_mean': self._mean,
+                        '_std': self._std,
+                        '_count': self._count,
+                        '_sum': self._sum,
+                        '_sum_squares': self._sum_squares}
 
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        latest_checkpoint = -1
-        # get all checkpoint files
-        for fname in os.listdir(checkpoint_dir):
-            path = os.path.join(checkpoint_dir, fname)
-            if os.path.isdir(path):
-                continue
-            checkpoint_id = int(fname.split('.')[0])
-            if checkpoint_id > latest_checkpoint:
-                latest_checkpoint = checkpoint_id
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
 
-        with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
-            temp_running_observation_stats = pickle.load(f)
-            self.__dict__.update(temp_running_observation_stats)
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir, checkpoint_prefix)
+
+        if latest_checkpoint_filename == '':
+            raise ValueError("Could not find a NumpySharedRunningStats checkpoint file to restore from.")
+
+        with open(os.path.join(checkpoint_dir, latest_checkpoint_filename), 'rb') as f:
+            saved_dict = pickle.load(f)
+            self.__dict__.update(saved_dict)
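
A round-trip sketch (hypothetical values and filename) of the new '.srs' format: only the five statistics keys are pickled, rather than the full __dict__, so the same file can be written by TFSharedRunningStats and read back by NumpySharedRunningStats.

    import os
    import pickle
    import tempfile

    dict_to_save = {'_mean': 0.0, '_std': 1.0, '_count': 100,
                    '_sum': 0.0, '_sum_squares': 100.0}

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        path = os.path.join(checkpoint_dir, '10_Step-2000.ckpt.agent.filters.srs')
        with open(path, 'wb') as f:
            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
        with open(path, 'rb') as f:
            assert pickle.load(f) == dict_to_save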