
Fixes for having NumpySharedRunningStats syncing on multi-node (#139)

1. Use the standard checkpoint prefix so that the data store can pick the files up and sync them to S3.
2. Remove the reference to Redis so that it is not pulled into the pickle.
3. Enable restoring a checkpoint that was saved by a single-node multi-worker run into a single-worker run.
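
To illustrate the naming convention this introduces (the agent, observation and filter names below are made up, not taken from the diff): every piece of auxiliary state now shares the main checkpoint's file name as its prefix, which is what lets a prefix-matching data store sync it to S3 together with the checkpoint itself.

    # Illustrative sketch only: 'agent', 'observation' and 'normalize_observation' are hypothetical names.
    filename = "{}_Step-{}.ckpt".format(4, 1000)          # main checkpoint file name built by GraphManager
    agent_prefix = '.'.join([filename, 'agent'])          # Agent.save_checkpoint appends the agent's full_name_id
    srs_name = agent_prefix + '.filters.observation_filters.observation.normalize_observation.srs'
    print(filename)    # 4_Step-1000.ckpt
    print(srs_name)    # 4_Step-1000.ckpt.agent.filters.observation_filters.observation.normalize_observation.srs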
Author: Gal Leibovich
Date: 2018-11-23 16:11:47 +02:00
Committed by: GitHub
Parent: 87a7848b0a
Commit: a1c56edd98
12 changed files with 154 additions and 99 deletions

View File

@@ -925,30 +925,31 @@ class Agent(AgentInterface):
         self.input_filter.observation_filters['attention'].crop_high = action[1]
         self.output_filter.action_filters['masking'].set_masking(action[0], action[1])

-    def save_checkpoint(self, checkpoint_id: int) -> None:
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
         """
         Allows agents to store additional information when saving checkpoints.
-        :param checkpoint_id: the id of the checkpoint
+        :param checkpoint_prefix: The prefix of the checkpoint file to save
         :return: None
         """
-        checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
-                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
-        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        checkpoint_dir = self.ap.task_parameters.checkpoint_save_dir
+        checkpoint_prefix = '.'.join([checkpoint_prefix] + self.full_name_id.split('/'))  # adds both level name and agent name
+
+        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)

     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         """
         Allows agents to store additional information when saving checkpoints.
-        :param checkpoint_id: the id of the checkpoint
+        :param checkpoint_dir: The checkpoint dir to restore from
         :return: None
         """
-        checkpoint_dir = os.path.join(checkpoint_dir,
-                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
-        self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
-        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
+        checkpoint_prefix = '.'.join(self.full_name_id.split('/'))  # adds both level name and agent name
+        self.input_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
+        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)

         # no output filters currently have an internal state to restore
         # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
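
For reference, a minimal sketch of the new prefix composition (the full_name_id value below is hypothetical): the level/agent path is folded into the file-name prefix instead of being used as per-agent sub-directories.

    full_name_id = 'top_level/main_agent'                 # hypothetical agent full name
    checkpoint_prefix = '4_Step-1000.ckpt'                # the file name handed down from GraphManager
    checkpoint_prefix = '.'.join([checkpoint_prefix] + full_name_id.split('/'))
    print(checkpoint_prefix)                              # 4_Step-1000.ckpt.top_level.main_agent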

View File

@@ -98,10 +98,10 @@ class AgentInterface(object):
""" """
raise NotImplementedError("") raise NotImplementedError("")
def save_checkpoint(self, checkpoint_id: int) -> None: def save_checkpoint(self, checkpoint_prefix: str) -> None:
""" """
Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc. Save the model of the agent to the disk. This can contain the network parameters, the memory of the agent, etc.
:param checkpoint_id: the checkpoint id to use for saving :param checkpoint_prefix: The prefix of the checkpoint file to save
:return: None :return: None
""" """
raise NotImplementedError("") raise NotImplementedError("")

View File

@@ -389,8 +389,8 @@ class CompositeAgent(AgentInterface):
         # probably better to only return the agents' goal_reached decisions.
         return episode_ended

-    def save_checkpoint(self, checkpoint_id: int) -> None:
-        [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
+        [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()]

     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]

View File

@@ -203,7 +203,7 @@ class NECAgent(ValueOptimizationAgent):
         self.networks['main'].online_network.output_heads[0].DND.add(self.current_episode_state_embeddings,
                                                                      actions, discounted_rewards)

-    def save_checkpoint(self, checkpoint_id):
-        super().save_checkpoint(checkpoint_id)
-        with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
+    def save_checkpoint(self, checkpoint_prefix):
+        super().save_checkpoint(checkpoint_prefix)
+        with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_prefix) + '.dnd'), 'wb') as f:
             pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
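
A rough sketch of the resulting DND file naming, using a stand-in object and hypothetical paths: the DND is now pickled next to the checkpoint as '<checkpoint_prefix>.dnd' instead of '<checkpoint_id>.dnd'.

    import os
    import pickle

    checkpoint_save_dir = '/tmp/checkpoint'               # hypothetical save dir
    checkpoint_prefix = '4_Step-1000.ckpt'                # prefix handed down from GraphManager
    os.makedirs(checkpoint_save_dir, exist_ok=True)
    dnd_stub = {'keys': [], 'values': []}                 # stand-in for the real DND object
    with open(os.path.join(checkpoint_save_dir, checkpoint_prefix + '.dnd'), 'wb') as f:
        pickle.dump(dnd_stub, f, pickle.HIGHEST_PROTOCOL)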

View File

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import pickle

 import numpy as np
 import tensorflow as tf
@@ -46,10 +48,10 @@ class TFSharedRunningStats(SharedRunningStats):
             initializer=tf.constant_initializer(0.0),
             name="running_sum", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
-        self._sum_squared = tf.get_variable(
+        self._sum_squares = tf.get_variable(
             dtype=tf.float64,
             initializer=tf.constant_initializer(self.epsilon),
-            name="running_sum_squared", trainable=False, shape=shape, validate_shape=False,
+            name="running_sum_squares", trainable=False, shape=shape, validate_shape=False,
             collections=[tf.GraphKeys.GLOBAL_VARIABLES])
         self._count = tf.get_variable(
             dtype=tf.float64,
@@ -59,17 +61,17 @@ class TFSharedRunningStats(SharedRunningStats):
             self._shape = None

         self._mean = tf.div(self._sum, self._count, name="mean")
-        self._std = tf.sqrt(tf.maximum((self._sum_squared - self._count*tf.square(self._mean))
+        self._std = tf.sqrt(tf.maximum((self._sum_squares - self._count * tf.square(self._mean))
                                        / tf.maximum(self._count-1, 1), self.epsilon), name="stdev")

         self.tf_mean = tf.cast(self._mean, 'float32')
         self.tf_std = tf.cast(self._std, 'float32')

         self.new_sum = tf.placeholder(dtype=tf.float64, name='sum')
-        self.new_sum_squared = tf.placeholder(dtype=tf.float64, name='var')
+        self.new_sum_squares = tf.placeholder(dtype=tf.float64, name='var')
         self.newcount = tf.placeholder(shape=[], dtype=tf.float64, name='count')

         self._inc_sum = tf.assign_add(self._sum, self.new_sum, use_locking=True)
-        self._inc_sum_squared = tf.assign_add(self._sum_squared, self.new_sum_squared, use_locking=True)
+        self._inc_sum_squares = tf.assign_add(self._sum_squares, self.new_sum_squares, use_locking=True)
         self._inc_count = tf.assign_add(self._count, self.newcount, use_locking=True)

         self.raw_obs = tf.placeholder(dtype=tf.float64, name='raw_obs')
@@ -84,10 +86,10 @@ class TFSharedRunningStats(SharedRunningStats):
     def push_val(self, x):
         x = x.astype('float64')
-        self.sess.run([self._inc_sum, self._inc_sum_squared, self._inc_count],
+        self.sess.run([self._inc_sum, self._inc_sum_squares, self._inc_count],
                       feed_dict={
                           self.new_sum: x.sum(axis=0).ravel(),
-                          self.new_sum_squared: np.square(x).sum(axis=0).ravel(),
+                          self.new_sum_squares: np.square(x).sum(axis=0).ravel(),
                           self.newcount: np.array(len(x), dtype='float64')
                       })
         if self._shape is None:
@@ -117,11 +119,11 @@ class TFSharedRunningStats(SharedRunningStats):
     def shape(self, val):
         self._shape = val
         self.new_sum.set_shape(val)
-        self.new_sum_squared.set_shape(val)
+        self.new_sum_squares.set_shape(val)
         self.tf_mean.set_shape(val)
         self.tf_std.set_shape(val)
         self._sum.set_shape(val)
-        self._sum_squared.set_shape(val)
+        self._sum_squares.set_shape(val)

     def normalize(self, batch):
         if self.clip_values is not None:
@@ -129,10 +131,25 @@ class TFSharedRunningStats(SharedRunningStats):
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})

-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        # the stats are part of the TF graph - no need to explicitly save anything
-        pass
-
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        # the stats are part of the TF graph - no need to explicitly restore anything
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, no need to do anything special for
+        # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run.
+        # Nevertheless, if we'll want to restore a checkpoint back to either a * single-worker, or a
+        # multi-node-multi-worker * run, we have to save the internal state, so that it can be restored to the
+        # NumpySharedRunningStats implementation.
+        dict_to_save = {'_mean': self.mean,
+                        '_std': self.std,
+                        '_count': self.n,
+                        '_sum': self.sess.run(self._sum),
+                        '_sum_squares': self.sess.run(self._sum_squares)}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        # Since the internal state is maintained as part of the TF graph, no need to do anything special for
+        # save/restore, when going from single-node-multi-thread run back to a single-node-multi-worker run.
+        # Restoring from either a * single-worker, or a multi-node-multi-worker * run, to a single-node-multi-thread run
+        # is not supported.
         pass
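
A rough sketch of the '.srs' handoff this enables, with hypothetical values and paths: the payload keys mirror NumpySharedRunningStats attribute names, so stats collected in the TF graph can later be restored by the numpy implementation with a plain pickle load.

    import pickle

    payload = {'_mean': 0.0, '_std': 1.0, '_count': 100.0,
               '_sum': 0.0, '_sum_squares': 100.0}
    path = '/tmp/4_Step-1000.ckpt.agent.filters.observation_filters.observation.normalize_observation.srs'
    with open(path, 'wb') as f:
        pickle.dump(payload, f, pickle.HIGHEST_PROTOCOL)
    with open(path, 'rb') as f:
        restored = pickle.load(f)                         # later merged via self.__dict__.update(restored)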

View File

@@ -66,19 +66,20 @@ class Filter(object):
""" """
pass pass
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None: def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
""" """
Save the filter's internal state to a checkpoint to file, so that it can be later restored. Save the filter's internal state to a checkpoint to file, so that it can be later restored.
:param checkpoint_dir: the directory in which to save the filter :param checkpoint_dir: the directory in which to save the filter's state
:param checkpoint_id: the checkpoint's ID :param checkpoint_prefix: the prefix of the checkpoint file to save
:return: None :return: None
""" """
pass pass
def restore_state_from_checkpoint(self, checkpoint_dir)->None: def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
""" """
Save the filter's internal state to a checkpoint to file, so that it can be later restored. Save the filter's internal state to a checkpoint to file, so that it can be later restored.
:param checkpoint_dir: the directory in which to save the filter :param checkpoint_dir: the directory from which to restore
:param checkpoint_prefix: the checkpoint prefix to look for
:return: None :return: None
""" """
pass pass
@@ -221,15 +222,25 @@ class OutputFilter(Filter):
""" """
del self._action_filters[filter_name] del self._action_filters[filter_name]
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
""" """
Currently not in use for OutputFilter. Currently not in use for OutputFilter.
:param checkpoint_dir: :param checkpoint_dir: the directory in which to save the filter's state
:param checkpoint_id: :param checkpoint_prefix: the prefix of the checkpoint file to save
:return: :return:
""" """
pass pass
def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
"""
Currently not in use for OutputFilter.
:param checkpoint_dir: the directory from which to restore
:param checkpoint_prefix: the checkpoint prefix to look for
:return: None
"""
pass
class NoOutputFilter(OutputFilter): class NoOutputFilter(OutputFilter):
""" """
@@ -444,40 +455,45 @@ class InputFilter(Filter):
""" """
del self._reward_filters[filter_name] del self._reward_filters[filter_name]
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_prefix):
""" """
Save the filter's internal state to a checkpoint to file, so that it can be later restored. Save the filter's internal state to a checkpoint to file, so that it can be later restored.
:param checkpoint_dir: the directory in which to save the filter :param checkpoint_dir: the directory in which to save the filter's state
:param checkpoint_id: the checkpoint's ID :param checkpoint_prefix: the prefix of the checkpoint file to save
:return: None :return: None
""" """
checkpoint_dir = os.path.join(checkpoint_dir, 'filters') checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
if self.name is not None: if self.name is not None:
checkpoint_dir = os.path.join(checkpoint_dir, self.name) checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items(): for filter_name, filter in self._reward_filters.items():
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id) checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
for observation_name, filters_dict in self._observation_filters.items(): for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items(): for filter_name, filter in filters_dict.items():
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name, checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
filter_name), checkpoint_id) filter_name])
filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
def restore_state_from_checkpoint(self, checkpoint_dir)->None: def restore_state_from_checkpoint(self, checkpoint_dir, checkpoint_prefix)->None:
""" """
Save the filter's internal state to a checkpoint to file, so that it can be later restored. Save the filter's internal state to a checkpoint to file, so that it can be later restored.
:param checkpoint_dir: the directory in which to save the filter :param checkpoint_dir: the directory from which to restore
:param checkpoint_prefix: the checkpoint prefix to look for
:return: None :return: None
""" """
checkpoint_dir = os.path.join(checkpoint_dir, 'filters') checkpoint_prefix = '.'.join([checkpoint_prefix, 'filters'])
if self.name is not None: if self.name is not None:
checkpoint_dir = os.path.join(checkpoint_dir, self.name) checkpoint_prefix = '.'.join([checkpoint_prefix, self.name])
for filter_name, filter in self._reward_filters.items(): for filter_name, filter in self._reward_filters.items():
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name)) checkpoint_prefix = '.'.join([checkpoint_prefix, 'reward_filters', filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
for observation_name, filters_dict in self._observation_filters.items(): for observation_name, filters_dict in self._observation_filters.items():
for filter_name, filter in filters_dict.items(): for filter_name, filter in filters_dict.items():
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', checkpoint_prefix = '.'.join([checkpoint_prefix, 'observation_filters', observation_name,
observation_name, filter_name)) filter_name])
filter.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)
class NoInputFilter(InputFilter): class NoInputFilter(InputFilter):
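
A minimal sketch of the prefix chain InputFilter now builds, with hypothetical filter and observation names: the filter hierarchy is encoded in the prefix rather than in a 'filters/...' directory tree, so the checkpoint directory stays flat.

    prefix = '4_Step-1000.ckpt.top_level.main_agent'      # as received from Agent.save_checkpoint
    prefix = '.'.join([prefix, 'filters'])
    prefix = '.'.join([prefix, 'observation_filters', 'observation', 'normalize_observation'])
    print(prefix)
    # 4_Step-1000.ckpt.top_level.main_agent.filters.observation_filters.observation.normalize_observation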

View File

@@ -82,11 +82,8 @@ class ObservationNormalizationFilter(ObservationFilter):
                                                              clip_values=(self.clip_min, self.clip_max))
         return input_observation_space

-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        if not os.path.exists(checkpoint_dir):
-            os.makedirs(checkpoint_dir)
-        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
-
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)

View File

@@ -75,8 +75,8 @@ class RewardNormalizationFilter(RewardFilter):
     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         return input_reward_space

-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        if not os.path.exists(checkpoint_dir):
-            os.makedirs(checkpoint_dir)
-        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_prefix)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        self.running_rewards_stats.restore_state_from_checkpoint(checkpoint_dir, checkpoint_prefix)

View File

@@ -581,10 +581,13 @@ class GraphManager(object):
     def save_checkpoint(self):
         if self.task_parameters.checkpoint_save_dir is None:
             self.task_parameters.checkpoint_save_dir = os.path.join(self.task_parameters.experiment_path, 'checkpoint')
-        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
-                                       "{}_Step-{}.ckpt".format(
-                                           self.checkpoint_id,
-                                           self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps]))
+
+        filename = "{}_Step-{}.ckpt".format(
+            self.checkpoint_id,
+            self.total_steps_counters[RunPhase.TRAIN][EnvironmentSteps])
+        checkpoint_path = os.path.join(self.task_parameters.checkpoint_save_dir,
+                                       filename)

         if not os.path.exists(os.path.dirname(checkpoint_path)):
             os.mkdir(os.path.dirname(checkpoint_path))  # Create directory structure

         if not isinstance(self.task_parameters, DistributedTaskParameters):
@@ -593,7 +596,7 @@ class GraphManager(object):
             saved_checkpoint_path = checkpoint_path

         # this is required in order for agents to save additional information like a DND for example
-        [manager.save_checkpoint(self.checkpoint_id) for manager in self.level_managers]
+        [manager.save_checkpoint(filename) for manager in self.level_managers]

         # the ONNX graph will be stored only if checkpoints are stored and the -onnx flag is used
         if self.task_parameters.export_onnx_graph:

View File

@@ -248,12 +248,13 @@ class LevelManager(EnvironmentInterface):
         return env_response_for_upper_level

-    def save_checkpoint(self, checkpoint_id: int) -> None:
+    def save_checkpoint(self, checkpoint_prefix: str) -> None:
         """
         Save checkpoints of the networks of all agents
+        :param: checkpoint_prefix: The prefix of the checkpoint file to save
         :return: None
         """
-        [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
+        [agent.save_checkpoint(checkpoint_prefix) for agent in self.agents.values()]

     def restore_checkpoint(self, checkpoint_dir: str) -> None:
         """

View File

@@ -269,14 +269,19 @@ class QDND(object):
 def load_dnd(model_dir):
-    max_id = 0
-    for f in [s for s in os.listdir(model_dir) if s.endswith('.dnd')]:
-        if int(f.split('.')[0]) > max_id:
-            max_id = int(f.split('.')[0])
-    model_path = str(max_id) + '.dnd'
-    with open(os.path.join(model_dir, model_path), 'rb') as f:
+    latest_checkpoint_id = -1
+    latest_checkpoint = ''
+
+    # get all checkpoint files
+    for fname in os.listdir(model_dir):
+        path = os.path.join(model_dir, fname)
+        if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
+            continue
+        checkpoint_id = int(fname.split('_')[0])
+        if checkpoint_id > latest_checkpoint_id:
+            latest_checkpoint = fname
+            latest_checkpoint_id = checkpoint_id
+
+    with open(os.path.join(model_dir, str(latest_checkpoint)), 'rb') as f:
         DND = pickle.load(f)

     for a in range(DND.num_actions):
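
A minimal sketch of the latest-checkpoint selection logic, using hypothetical file names: the newest state file is picked by parsing the leading checkpoint id out of the '{id}_Step-{steps}.ckpt...' prefix, rather than from a bare '{id}.dnd' name.

    fnames = ['4_Step-1000.ckpt.agent.dnd', '10_Step-2500.ckpt.agent.dnd']
    latest = max(fnames, key=lambda fname: int(fname.split('_')[0]))
    print(latest)                                         # 10_Step-2500.ckpt.agent.dnd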

View File

@@ -21,7 +21,6 @@ import redis
 import numpy as np


 class SharedRunningStatsSubscribe(threading.Thread):
     def __init__(self, shared_running_stats):
         super().__init__()
@@ -103,13 +102,28 @@ class SharedRunningStats(ABC):
         pass

     @abstractmethod
-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
         pass

     @abstractmethod
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
         pass

+    def get_latest_checkpoint(self, checkpoint_dir: str) -> str:
+        latest_checkpoint_id = -1
+        latest_checkpoint = ''
+
+        # get all checkpoint files
+        for fname in os.listdir(checkpoint_dir):
+            path = os.path.join(checkpoint_dir, fname)
+            if os.path.isdir(path) or fname.split('.')[-1] != 'srs':
+                continue
+            checkpoint_id = int(fname.split('_')[0])
+            if checkpoint_id > latest_checkpoint_id:
+                latest_checkpoint = fname
+                latest_checkpoint_id = checkpoint_id
+
+        return latest_checkpoint
+

 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -164,21 +178,22 @@ class NumpySharedRunningStats(SharedRunningStats):
         # no session for the numpy implementation
         pass

-    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
-        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
-            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
-
-    def restore_state_from_checkpoint(self, checkpoint_dir: str):
-        latest_checkpoint = -1
-        # get all checkpoint files
-        for fname in os.listdir(checkpoint_dir):
-            path = os.path.join(checkpoint_dir, fname)
-            if os.path.isdir(path):
-                continue
-            checkpoint_id = int(fname.split('.')[0])
-            if checkpoint_id > latest_checkpoint:
-                latest_checkpoint = checkpoint_id
-
-        with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
-            temp_running_observation_stats = pickle.load(f)
-            self.__dict__.update(temp_running_observation_stats)
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: int):
+        dict_to_save = {'_mean': self._mean,
+                        '_std': self._std,
+                        '_count': self._count,
+                        '_sum': self._sum,
+                        '_sum_squares': self._sum_squares}
+
+        with open(os.path.join(checkpoint_dir, str(checkpoint_prefix) + '.srs'), 'wb') as f:
+            pickle.dump(dict_to_save, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str, checkpoint_prefix: str):
+        latest_checkpoint_filename = self.get_latest_checkpoint(checkpoint_dir)
+
+        if latest_checkpoint_filename == '':
+            raise ValueError("Could not find NumpySharedRunningStats checkpoint file. ")
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint_filename)), 'rb') as f:
+            saved_dict = pickle.load(f)
+            self.__dict__.update(saved_dict)
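
A minimal sketch of why restoring via a whitelisted dict works, using a hypothetical stand-in class: only the saved stats keys are merged into __dict__, and saving that restricted dict (instead of the full self.__dict__) is also what keeps the non-picklable Redis pub/sub reference out of the checkpoint file.

    class Stats:
        def __init__(self):
            self._mean, self._std, self._count = 0.0, 1.0, 0
            self.pubsub = object()                        # stand-in for the Redis pub/sub backend

    saved = {'_mean': 2.5, '_std': 0.3, '_count': 10}     # contents read back from a '.srs' file
    s = Stats()
    s.__dict__.update(saved)                              # stats restored; pubsub reference untouched
    print(s._mean, s._count)                              # 2.5 10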