diff --git a/rl_coach/agents/agent.py b/rl_coach/agents/agent.py
index a9ecbb2..af55b5f 100644
--- a/rl_coach/agents/agent.py
+++ b/rl_coach/agents/agent.py
@@ -15,6 +15,7 @@
 #
 import copy
+import os
 import random
 from collections import OrderedDict
 from typing import Dict, List, Union, Tuple
 
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
 
         # filters
         self.input_filter = self.ap.input_filter
+        self.input_filter.set_name('input_filter')
         self.output_filter = self.ap.output_filter
+        self.output_filter.set_name('output_filter')
         self.pre_network_filter = self.ap.pre_network_filter
+        self.pre_network_filter.set_name('pre_network_filter')
+
         device = self.replicated_device if self.replicated_device else self.worker_device
 
         # TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
         :param checkpoint_id: the id of the checkpoint
         :return: None
         """
-        pass
+        checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
+                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
+        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        """
+        Restores the agent's internal state (e.g. its input filters) from a checkpoint.
+
+        :param checkpoint_dir: the directory to restore the checkpoint from
+        :return: None
+        """
+        checkpoint_dir = os.path.join(checkpoint_dir,
+                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
+        self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
+        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
+
+        # no output filters currently have an internal state to restore
+        # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
 
     def sync(self) -> None:
         """
diff --git a/rl_coach/agents/composite_agent.py b/rl_coach/agents/composite_agent.py
index 42dcad9..354c36a 100644
--- a/rl_coach/agents/composite_agent.py
+++ b/rl_coach/agents/composite_agent.py
@@ -392,6 +392,9 @@ class CompositeAgent(AgentInterface):
     def save_checkpoint(self, checkpoint_id: int) -> None:
         [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
 
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
+
     def set_incoming_directive(self, action: ActionType) -> None:
         self.incoming_action = action
         if isinstance(self.decision_policy, SingleDecider) and isinstance(self.in_action_space, AgentSelection):
diff --git a/rl_coach/agents/nec_agent.py b/rl_coach/agents/nec_agent.py
index a576dde..4baf08a 100644
--- a/rl_coach/agents/nec_agent.py
+++ b/rl_coach/agents/nec_agent.py
@@ -204,5 +204,6 @@ class NECAgent(ValueOptimizationAgent):
                                          actions, discounted_rewards)
 
     def save_checkpoint(self, checkpoint_id):
+        super().save_checkpoint(checkpoint_id)
         with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
             pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)
diff --git a/rl_coach/architectures/tensorflow_components/shared_variables.py b/rl_coach/architectures/tensorflow_components/shared_variables.py
index 33a2c05..4a38e3e 100644
--- a/rl_coach/architectures/tensorflow_components/shared_variables.py
+++ b/rl_coach/architectures/tensorflow_components/shared_variables.py
@@ -128,3 +128,11 @@ class TFSharedRunningStats(SharedRunningStats):
             return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        # the stats are part of the TF graph - no need to explicitly save anything
+        pass
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        # the stats are part of the TF graph - no need to explicitly restore anything
+        pass
diff --git a/rl_coach/base_parameters.py b/rl_coach/base_parameters.py
index 2519327..5ba65f8 100644
--- a/rl_coach/base_parameters.py
+++ b/rl_coach/base_parameters.py
@@ -506,7 +506,7 @@ class AgentParameters(Parameters):
         self.input_filter = None
         self.output_filter = None
         self.pre_network_filter = NoInputFilter()
-        self.full_name_id = None  # TODO: do we really want to hold this parameter here?
+        self.full_name_id = None
         self.name = None
         self.is_a_highest_level_agent = True
         self.is_a_lowest_level_agent = True
diff --git a/rl_coach/coach.py b/rl_coach/coach.py
index 1994c68..b5c23a2 100644
--- a/rl_coach/coach.py
+++ b/rl_coach/coach.py
@@ -118,7 +118,7 @@ def handle_distributed_coach_tasks(graph_manager, args):
     )
 
 
-def handle_distributed_coach_orchestrator(graph_manager, args):
+def handle_distributed_coach_orchestrator(args):
     from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
         RunTypeParameters
 
diff --git a/rl_coach/filters/filter.py b/rl_coach/filters/filter.py
index dbf59f2..28a195f 100644
--- a/rl_coach/filters/filter.py
+++ b/rl_coach/filters/filter.py
@@ -15,6 +15,7 @@
 #
 import copy
+import os
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Union, List
@@ -25,12 +26,13 @@ from rl_coach.utils import force_list
 
 
 class Filter(object):
-    def __init__(self):
-        pass
+    def __init__(self, name=None):
+        self.name = name
 
     def reset(self) -> None:
         """
         Called from reset() and implements the reset logic for the filter.
+
         :return: None
         """
         pass
@@ -64,14 +66,39 @@
         """
         pass
 
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id) -> None:
+        """
+        Save the filter's internal state to a checkpoint file, so that it can later be restored.
+        :param checkpoint_dir: the directory in which to save the filter
+        :param checkpoint_id: the checkpoint's ID
+        :return: None
+        """
+        pass
+
+    def restore_state_from_checkpoint(self, checkpoint_dir) -> None:
+        """
+        Restore the filter's internal state from a previously saved checkpoint file.
+        :param checkpoint_dir: the directory from which to restore the filter
+        :return: None
+        """
+        pass
+
+    def set_name(self, name: str) -> None:
+        """
+        Set the filter's name
+        :param name: the filter's name
+        :return: None
+        """
+        self.name = name
+
 
 class OutputFilter(Filter):
     """
     An output filter is a module that filters the output from an agent to the environment.
""" def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None, - is_a_reference_filter: bool=False): - super().__init__() + is_a_reference_filter: bool=False, name=None): + super().__init__(name) if action_filters is None: action_filters = OrderedDict([]) @@ -194,6 +221,15 @@ class OutputFilter(Filter): """ del self._action_filters[filter_name] + def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): + """ + Currently not in use for OutputFilter. + :param checkpoint_dir: + :param checkpoint_id: + :return: + """ + pass + class NoOutputFilter(OutputFilter): """ @@ -209,8 +245,8 @@ class InputFilter(Filter): """ def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None, reward_filters: Dict[str, 'RewardFilter']=None, - is_a_reference_filter: bool=False): - super().__init__() + is_a_reference_filter: bool=False, name=None): + super().__init__(name) if observation_filters is None: observation_filters = {} if reward_filters is None: @@ -299,7 +335,6 @@ class InputFilter(Filter): return filtered_data - def get_filtered_observation_space(self, observation_name: str, input_observation_space: ObservationSpace) -> ObservationSpace: """ @@ -409,12 +444,47 @@ class InputFilter(Filter): """ del self._reward_filters[filter_name] + def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id): + """ + Save the filter's internal state to a checkpoint to file, so that it can be later restored. + :param checkpoint_dir: the directory in which to save the filter + :param checkpoint_id: the checkpoint's ID + :return: None + """ + checkpoint_dir = os.path.join(checkpoint_dir, 'filters') + if self.name is not None: + checkpoint_dir = os.path.join(checkpoint_dir, self.name) + for filter_name, filter in self._reward_filters.items(): + filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name), checkpoint_id) + + for observation_name, filters_dict in self._observation_filters.items(): + for filter_name, filter in filters_dict.items(): + filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', observation_name, + filter_name), checkpoint_id) + + def restore_state_from_checkpoint(self, checkpoint_dir)->None: + """ + Save the filter's internal state to a checkpoint to file, so that it can be later restored. + :param checkpoint_dir: the directory in which to save the filter + :return: None + """ + checkpoint_dir = os.path.join(checkpoint_dir, 'filters') + if self.name is not None: + checkpoint_dir = os.path.join(checkpoint_dir, self.name) + for filter_name, filter in self._reward_filters.items(): + filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name)) + + for observation_name, filters_dict in self._observation_filters.items(): + for filter_name, filter in filters_dict.items(): + filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters', + observation_name, filter_name)) + class NoInputFilter(InputFilter): """ Creates an empty input filter. 
     """
     def __init__(self):
-        super().__init__(is_a_reference_filter=False)
+        super().__init__(is_a_reference_filter=False, name='no_input_filter')
diff --git a/rl_coach/filters/observation/observation_normalization_filter.py b/rl_coach/filters/observation/observation_normalization_filter.py
index 796ef31..219fc0e 100644
--- a/rl_coach/filters/observation/observation_normalization_filter.py
+++ b/rl_coach/filters/observation/observation_normalization_filter.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import pickle
 from typing import List
 
 import numpy as np
@@ -79,3 +81,12 @@ class ObservationNormalizationFilter(ObservationFilter):
         self.running_observation_stats.set_params(shape=input_observation_space.shape,
                                                   clip_values=(self.clip_min, self.clip_max))
         return input_observation_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)
diff --git a/rl_coach/filters/reward/reward_normalization_filter.py b/rl_coach/filters/reward/reward_normalization_filter.py
index daf5562..c6c489c 100644
--- a/rl_coach/filters/reward/reward_normalization_filter.py
+++ b/rl_coach/filters/reward/reward_normalization_filter.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 import numpy as np
 
 
@@ -74,3 +74,9 @@ class RewardNormalizationFilter(RewardFilter):
 
     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         return input_reward_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
\ No newline at end of file
diff --git a/rl_coach/graph_managers/graph_manager.py b/rl_coach/graph_managers/graph_manager.py
index b68f55c..05153c1 100644
--- a/rl_coach/graph_managers/graph_manager.py
+++ b/rl_coach/graph_managers/graph_manager.py
@@ -565,7 +565,7 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+        if self.task_parameters.checkpoint_restore_dir:
             checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
             screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
 
@@ -577,6 +577,8 @@
             else:
                 raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
 
+            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
         if self.task_parameters.checkpoint_save_secs \
diff --git a/rl_coach/level_manager.py b/rl_coach/level_manager.py
index 312a5be..1a344f1 100644
--- a/rl_coach/level_manager.py
+++ b/rl_coach/level_manager.py
@@ -255,6 +255,13 @@ class LevelManager(EnvironmentInterface):
         """
         [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
 
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        """
+        Restores the state (e.g. input filters) of all agents from a checkpoint directory
+        :return: None
+        """
+        [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
+
     def sync(self) -> None:
         """
         Sync the networks of the agents with the global network parameters
diff --git a/rl_coach/presets/CartPole_ClippedPPO.py b/rl_coach/presets/CartPole_ClippedPPO.py
index 50014c4..4911ae5 100644
--- a/rl_coach/presets/CartPole_ClippedPPO.py
+++ b/rl_coach/presets/CartPole_ClippedPPO.py
@@ -5,6 +5,7 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+from rl_coach.filters.filter import InputFilter
 from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.graph_managers.graph_manager import ScheduleParameters
@@ -47,6 +48,7 @@ agent_params.algorithm.num_steps_between_copying_online_weights_to_target = Envi
 
 # Distributed Coach synchronization type.
 agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
 
+agent_params.pre_network_filter = InputFilter()
 agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
                                                        ObservationNormalizationFilter(name='normalize_observation'))
diff --git a/rl_coach/utilities/shared_running_stats.py b/rl_coach/utilities/shared_running_stats.py
index c2f1dd4..c76f232 100644
--- a/rl_coach/utilities/shared_running_stats.py
+++ b/rl_coach/utilities/shared_running_stats.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
 from abc import ABC, abstractmethod
 import threading
 import pickle
@@ -102,6 +102,14 @@ class SharedRunningStats(ABC):
     def set_session(self, sess):
         pass
 
+    @abstractmethod
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        pass
+
+    @abstractmethod
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        pass
+
 
 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -156,4 +164,21 @@
 
         # no session for the numpy implementation
         pass
 
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
+            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        latest_checkpoint = -1
+        # get all checkpoint files
+        for fname in os.listdir(checkpoint_dir):
+            path = os.path.join(checkpoint_dir, fname)
+            if os.path.isdir(path):
+                continue
+            checkpoint_id = int(fname.split('.')[0])
+            if checkpoint_id > latest_checkpoint:
+                latest_checkpoint = checkpoint_id
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
+            temp_running_observation_stats = pickle.load(f)
+            self.__dict__.update(temp_running_observation_stats)
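
Note for reviewers (not part of the patch): with these changes applied, a normalization filter's running statistics are written under a per-agent directory such as <checkpoint_save_dir>/<level_name>/<agent_name>/filters/pre_network_filter/observation_filters/observation/normalize_observation/<checkpoint_id>.srs, and restore_state_from_checkpoint picks the highest-numbered .srs file it finds there. Below is a minimal, self-contained sketch of that pickle round-trip pattern. TinyRunningStats is a hypothetical stand-in used purely for illustration; it is not an rl_coach class, and NumpySharedRunningStats itself pickles its full __dict__ (numpy arrays included) rather than the two toy fields shown here.

import os
import pickle
import tempfile


class TinyRunningStats:
    """Hypothetical stand-in for NumpySharedRunningStats, for illustration only."""
    def __init__(self):
        self.count = 0
        self.mean = 0.0

    def add(self, x):
        # incremental mean update, standing in for the real running statistics
        self.count += 1
        self.mean += (x - self.mean) / self.count

    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
        # mirrors the patch: dump the instance __dict__ to '<checkpoint_id>.srs'
        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)

    def restore_state_from_checkpoint(self, checkpoint_dir):
        # mirrors the patch: restore from the highest checkpoint id found in the directory
        latest = max(int(fname.split('.')[0]) for fname in os.listdir(checkpoint_dir)
                     if not os.path.isdir(os.path.join(checkpoint_dir, fname)))
        with open(os.path.join(checkpoint_dir, str(latest) + '.srs'), 'rb') as f:
            self.__dict__.update(pickle.load(f))


if __name__ == '__main__':
    checkpoint_dir = tempfile.mkdtemp()
    stats = TinyRunningStats()
    for value in (1.0, 2.0, 3.0):
        stats.add(value)
    stats.save_state_to_checkpoint(checkpoint_dir, 0)

    restored = TinyRunningStats()
    restored.restore_state_from_checkpoint(checkpoint_dir)
    print(restored.count, restored.mean)  # -> 3 2.0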