mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 11:10:20 +01:00
Save filters' internal state (#127)
* save filters internal state * moving the restore to be made from within NumpyRunningStats
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Dict, List, Union, Tuple
|
from typing import Dict, List, Union, Tuple
|
||||||
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
|
|||||||
|
|
||||||
# filters
|
# filters
|
||||||
self.input_filter = self.ap.input_filter
|
self.input_filter = self.ap.input_filter
|
||||||
|
self.input_filter.set_name('input_filter')
|
||||||
self.output_filter = self.ap.output_filter
|
self.output_filter = self.ap.output_filter
|
||||||
|
self.output_filter.set_name('output_filter')
|
||||||
self.pre_network_filter = self.ap.pre_network_filter
|
self.pre_network_filter = self.ap.pre_network_filter
|
||||||
|
self.pre_network_filter.set_name('pre_network_filter')
|
||||||
|
|
||||||
device = self.replicated_device if self.replicated_device else self.worker_device
|
device = self.replicated_device if self.replicated_device else self.worker_device
|
||||||
|
|
||||||
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
|
# TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
|
||||||
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
|
|||||||
:param checkpoint_id: the id of the checkpoint
|
:param checkpoint_id: the id of the checkpoint
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
pass
|
checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
|
||||||
|
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||||
|
self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||||
|
self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||||
|
self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||||
|
|
||||||
|
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||||
|
"""
|
||||||
|
Allows agents to restore additional information when loading checkpoints.
|
||||||
|
|
||||||
|
:param checkpoint_dir: the directory from which to restore the checkpoint
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
checkpoint_dir = os.path.join(checkpoint_dir,
|
||||||
|
*(self.full_name_id.split('/'))) # adds both level name and agent name
|
||||||
|
self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||||
|
self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||||
|
|
||||||
|
# no output filters currently have an internal state to restore
|
||||||
|
# self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
|
||||||
|
|
||||||
def sync(self) -> None:
|
def sync(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -392,6 +392,9 @@ class CompositeAgent(AgentInterface):
|
|||||||
def save_checkpoint(self, checkpoint_id: int) -> None:
|
def save_checkpoint(self, checkpoint_id: int) -> None:
|
||||||
[agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
|
[agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
|
||||||
|
|
||||||
|
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||||
|
[agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
|
||||||
|
|
||||||
def set_incoming_directive(self, action: ActionType) -> None:
|
def set_incoming_directive(self, action: ActionType) -> None:
|
||||||
self.incoming_action = action
|
self.incoming_action = action
|
||||||
if isinstance(self.decision_policy, SingleDecider) and isinstance(self.in_action_space, AgentSelection):
|
if isinstance(self.decision_policy, SingleDecider) and isinstance(self.in_action_space, AgentSelection):
|
||||||
|
|||||||
@@ -204,5 +204,6 @@ class NECAgent(ValueOptimizationAgent):
|
|||||||
actions, discounted_rewards)
|
actions, discounted_rewards)
|
||||||
|
|
||||||
def save_checkpoint(self, checkpoint_id):
|
def save_checkpoint(self, checkpoint_id):
|
||||||
|
super().save_checkpoint(checkpoint_id)
|
||||||
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
|
with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
|
||||||
pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|||||||
@@ -128,3 +128,11 @@ class TFSharedRunningStats(SharedRunningStats):
|
|||||||
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
|
||||||
else:
|
else:
|
||||||
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||||
|
# the stats are part of the TF graph - no need to explicitly save anything
|
||||||
|
pass
|
||||||
|
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||||
|
# the stats are part of the TF graph - no need to explicitly restore anything
|
||||||
|
pass
|
||||||
|
|||||||
@@ -506,7 +506,7 @@ class AgentParameters(Parameters):
|
|||||||
self.input_filter = None
|
self.input_filter = None
|
||||||
self.output_filter = None
|
self.output_filter = None
|
||||||
self.pre_network_filter = NoInputFilter()
|
self.pre_network_filter = NoInputFilter()
|
||||||
self.full_name_id = None # TODO: do we really want to hold this parameter here?
|
self.full_name_id = None
|
||||||
self.name = None
|
self.name = None
|
||||||
self.is_a_highest_level_agent = True
|
self.is_a_highest_level_agent = True
|
||||||
self.is_a_lowest_level_agent = True
|
self.is_a_lowest_level_agent = True
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ def handle_distributed_coach_tasks(graph_manager, args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_distributed_coach_orchestrator(graph_manager, args):
|
def handle_distributed_coach_orchestrator(args):
|
||||||
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
|
from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
|
||||||
RunTypeParameters
|
RunTypeParameters
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Dict, Union, List
|
from typing import Dict, Union, List
|
||||||
@@ -25,12 +26,13 @@ from rl_coach.utils import force_list
|
|||||||
|
|
||||||
|
|
||||||
class Filter(object):
|
class Filter(object):
|
||||||
def __init__(self):
|
def __init__(self, name=None):
|
||||||
pass
|
self.name = name
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
"""
|
"""
|
||||||
Called from reset() and implements the reset logic for the filter.
|
Called from reset() and implements the reset logic for the filter.
|
||||||
|
:param name: the filter's name
|
||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
@@ -64,14 +66,39 @@ class Filter(object):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None:
|
||||||
|
"""
|
||||||
|
Save the filter's internal state to a checkpoint file, so that it can be later restored.
|
||||||
|
:param checkpoint_dir: the directory in which to save the filter
|
||||||
|
:param checkpoint_id: the checkpoint's ID
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
|
||||||
|
"""
|
||||||
|
Restore the filter's internal state from a previously saved checkpoint file.
|
||||||
|
:param checkpoint_dir: the directory from which to restore the filter
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_name(self, name: str) -> None:
|
||||||
|
"""
|
||||||
|
Set the filter's name
|
||||||
|
:param name: the filter's name
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
class OutputFilter(Filter):
|
class OutputFilter(Filter):
|
||||||
"""
|
"""
|
||||||
An output filter is a module that filters the output from an agent to the environment.
|
An output filter is a module that filters the output from an agent to the environment.
|
||||||
"""
|
"""
|
||||||
def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None,
|
def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None,
|
||||||
is_a_reference_filter: bool=False):
|
is_a_reference_filter: bool=False, name=None):
|
||||||
super().__init__()
|
super().__init__(name)
|
||||||
|
|
||||||
if action_filters is None:
|
if action_filters is None:
|
||||||
action_filters = OrderedDict([])
|
action_filters = OrderedDict([])
|
||||||
@@ -194,6 +221,15 @@ class OutputFilter(Filter):
|
|||||||
"""
|
"""
|
||||||
del self._action_filters[filter_name]
|
del self._action_filters[filter_name]
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
|
||||||
|
"""
|
||||||
|
Currently not in use for OutputFilter.
|
||||||
|
:param checkpoint_dir:
|
||||||
|
:param checkpoint_id:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NoOutputFilter(OutputFilter):
|
class NoOutputFilter(OutputFilter):
|
||||||
"""
|
"""
|
||||||
@@ -209,8 +245,8 @@ class InputFilter(Filter):
|
|||||||
"""
|
"""
|
||||||
def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None,
|
def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None,
|
||||||
reward_filters: Dict[str, 'RewardFilter']=None,
|
reward_filters: Dict[str, 'RewardFilter']=None,
|
||||||
is_a_reference_filter: bool=False):
|
is_a_reference_filter: bool=False, name=None):
|
||||||
super().__init__()
|
super().__init__(name)
|
||||||
if observation_filters is None:
|
if observation_filters is None:
|
||||||
observation_filters = {}
|
observation_filters = {}
|
||||||
if reward_filters is None:
|
if reward_filters is None:
|
||||||
@@ -299,7 +335,6 @@ class InputFilter(Filter):
|
|||||||
|
|
||||||
return filtered_data
|
return filtered_data
|
||||||
|
|
||||||
|
|
||||||
def get_filtered_observation_space(self, observation_name: str,
|
def get_filtered_observation_space(self, observation_name: str,
|
||||||
input_observation_space: ObservationSpace) -> ObservationSpace:
|
input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||||
"""
|
"""
|
||||||
@@ -409,12 +444,47 @@ class InputFilter(Filter):
|
|||||||
"""
|
"""
|
||||||
del self._reward_filters[filter_name]
|
del self._reward_filters[filter_name]
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
|
||||||
|
"""
|
||||||
|
Save the filter's internal state to a checkpoint file, so that it can be later restored.
|
||||||
|
:param checkpoint_dir: the directory in which to save the filter
|
||||||
|
:param checkpoint_id: the checkpoint's ID
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
|
||||||
|
if self.name is not None:
|
||||||
|
checkpoint_dir = os.path.join(checkpoint_dir, self.name)
|
||||||
|
for filter_name, filter in self._reward_filters.items():
|
||||||
|
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id)
|
||||||
|
|
||||||
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
|
for filter_name, filter in filters_dict.items():
|
||||||
|
filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name,
|
||||||
|
filter_name), checkpoint_id)
|
||||||
|
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
|
||||||
|
"""
|
||||||
|
Restore the filter's internal state from a previously saved checkpoint file.
|
||||||
|
:param checkpoint_dir: the directory from which to restore the filter
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
|
||||||
|
if self.name is not None:
|
||||||
|
checkpoint_dir = os.path.join(checkpoint_dir, self.name)
|
||||||
|
for filter_name, filter in self._reward_filters.items():
|
||||||
|
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name))
|
||||||
|
|
||||||
|
for observation_name, filters_dict in self._observation_filters.items():
|
||||||
|
for filter_name, filter in filters_dict.items():
|
||||||
|
filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters',
|
||||||
|
observation_name, filter_name))
|
||||||
|
|
||||||
|
|
||||||
class NoInputFilter(InputFilter):
|
class NoInputFilter(InputFilter):
|
||||||
"""
|
"""
|
||||||
Creates an empty input filter. Used only for readability when creating the presets
|
Creates an empty input filter. Used only for readability when creating the presets
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(is_a_reference_filter=False)
|
super().__init__(is_a_reference_filter=False, name='no_input_filter')
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -79,3 +81,12 @@ class ObservationNormalizationFilter(ObservationFilter):
|
|||||||
self.running_observation_stats.set_params(shape=input_observation_space.shape,
|
self.running_observation_stats.set_params(shape=input_observation_space.shape,
|
||||||
clip_values=(self.clip_min, self.clip_max))
|
clip_values=(self.clip_min, self.clip_max))
|
||||||
return input_observation_space
|
return input_observation_space
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||||
|
if not os.path.exists(checkpoint_dir):
|
||||||
|
os.makedirs(checkpoint_dir)
|
||||||
|
|
||||||
|
self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||||
|
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||||
|
self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -74,3 +74,9 @@ class RewardNormalizationFilter(RewardFilter):
|
|||||||
|
|
||||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||||
return input_reward_space
|
return input_reward_space
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||||
|
if not os.path.exists(checkpoint_dir):
|
||||||
|
os.makedirs(checkpoint_dir)
|
||||||
|
|
||||||
|
self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||||
@@ -565,7 +565,7 @@ class GraphManager(object):
|
|||||||
self.verify_graph_was_created()
|
self.verify_graph_was_created()
|
||||||
|
|
||||||
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
# TODO: find better way to load checkpoints that were saved with a global network into the online network
|
||||||
if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
|
if self.task_parameters.checkpoint_restore_dir:
|
||||||
|
|
||||||
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
|
checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
|
||||||
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
|
||||||
@@ -577,6 +577,8 @@ class GraphManager(object):
|
|||||||
else:
|
else:
|
||||||
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
|
raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
|
||||||
|
|
||||||
|
[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
|
||||||
|
|
||||||
def occasionally_save_checkpoint(self):
|
def occasionally_save_checkpoint(self):
|
||||||
# only the chief process saves checkpoints
|
# only the chief process saves checkpoints
|
||||||
if self.task_parameters.checkpoint_save_secs \
|
if self.task_parameters.checkpoint_save_secs \
|
||||||
|
|||||||
@@ -255,6 +255,13 @@ class LevelManager(EnvironmentInterface):
|
|||||||
"""
|
"""
|
||||||
[agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
|
[agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
|
||||||
|
|
||||||
|
def restore_checkpoint(self, checkpoint_dir: str) -> None:
|
||||||
|
"""
|
||||||
|
Restores the checkpointed state of all agents from the given checkpoint directory
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
[agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
|
||||||
|
|
||||||
def sync(self) -> None:
|
def sync(self) -> None:
|
||||||
"""
|
"""
|
||||||
Sync the networks of the agents with the global network parameters
|
Sync the networks of the agents with the global network parameters
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
|
|||||||
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
|
||||||
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
|
||||||
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
from rl_coach.exploration_policies.e_greedy import EGreedyParameters
|
||||||
|
from rl_coach.filters.filter import InputFilter
|
||||||
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
|
||||||
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
|
||||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||||
@@ -47,6 +48,7 @@ agent_params.algorithm.num_steps_between_copying_online_weights_to_target = Envi
|
|||||||
# Distributed Coach synchronization type.
|
# Distributed Coach synchronization type.
|
||||||
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
|
||||||
|
|
||||||
|
agent_params.pre_network_filter = InputFilter()
|
||||||
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
|
||||||
ObservationNormalizationFilter(name='normalize_observation'))
|
ObservationNormalizationFilter(name='normalize_observation'))
|
||||||
|
|
||||||
|
|||||||
@@ -13,7 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
import threading
|
import threading
|
||||||
import pickle
|
import pickle
|
||||||
@@ -102,6 +102,14 @@ class SharedRunningStats(ABC):
|
|||||||
def set_session(self, sess):
|
def set_session(self, sess):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class NumpySharedRunningStats(SharedRunningStats):
|
class NumpySharedRunningStats(SharedRunningStats):
|
||||||
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
def __init__(self, name, epsilon=1e-2, pubsub_params=None):
|
||||||
@@ -156,4 +164,21 @@ class NumpySharedRunningStats(SharedRunningStats):
|
|||||||
# no session for the numpy implementation
|
# no session for the numpy implementation
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
|
||||||
|
with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
|
||||||
|
pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
|
||||||
|
|
||||||
|
def restore_state_from_checkpoint(self, checkpoint_dir: str):
|
||||||
|
latest_checkpoint = -1
|
||||||
|
# get all checkpoint files
|
||||||
|
for fname in os.listdir(checkpoint_dir):
|
||||||
|
path = os.path.join(checkpoint_dir, fname)
|
||||||
|
if os.path.isdir(path):
|
||||||
|
continue
|
||||||
|
checkpoint_id = int(fname.split('.')[0])
|
||||||
|
if checkpoint_id > latest_checkpoint:
|
||||||
|
latest_checkpoint = checkpoint_id
|
||||||
|
|
||||||
|
with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
|
||||||
|
temp_running_observation_stats = pickle.load(f)
|
||||||
|
self.__dict__.update(temp_running_observation_stats)
|
||||||
|
|||||||
Reference in New Issue
Block a user