mirror of https://github.com/gryf/coach.git synced 2025-12-17 11:10:20 +01:00

Save filters' internal state (#127)

* save filters internal state

* moving the restore to be made from within NumpyRunningStats
Commit a112ee69f6 (parent 67eb9e4c28)
Author: Gal Leibovich, committed by GitHub
Date: 2018-11-20 17:21:48 +02:00
13 changed files with 173 additions and 14 deletions
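In short: filters (input, output, and pre-network) can now persist their internal state (e.g. running normalization statistics) alongside network checkpoints, and restore it when a checkpoint is loaded. A minimal sketch of the restore flow this commit wires up — the method names are taken from the diff below, while the surrounding objects and paths are illustrative:

    # after the framework restores the network weights,
    # GraphManager.restore_checkpoint fans out to every level manager:
    checkpoint_restore_dir = './experiment/checkpoints'  # illustrative
    for manager in graph_manager.level_managers:
        manager.restore_checkpoint(checkpoint_restore_dir)
    # each LevelManager forwards to its agents, and each Agent appends its
    # full_name_id and calls restore_state_from_checkpoint on its
    # input_filter and pre_network_filter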

View File

@@ -15,6 +15,7 @@
 #
 import copy
+import os
 import random
 from collections import OrderedDict
 from typing import Dict, List, Union, Tuple
@@ -108,8 +109,12 @@ class Agent(AgentInterface):
         # filters
         self.input_filter = self.ap.input_filter
+        self.input_filter.set_name('input_filter')
         self.output_filter = self.ap.output_filter
+        self.output_filter.set_name('output_filter')
         self.pre_network_filter = self.ap.pre_network_filter
+        self.pre_network_filter.set_name('pre_network_filter')
+
         device = self.replicated_device if self.replicated_device else self.worker_device
         # TODO-REMOVE This is a temporary flow dividing to 3 modes. To be converged to a single flow once distributed tf
@@ -923,7 +928,26 @@ class Agent(AgentInterface):
         :param checkpoint_id: the id of the checkpoint
         :return: None
         """
-        pass
+        checkpoint_dir = os.path.join(self.ap.task_parameters.checkpoint_save_dir,
+                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
+        self.input_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        self.output_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+        self.pre_network_filter.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        """
+        Allows agents to restore additional information (e.g. filter state) when restoring a checkpoint.
+
+        :param checkpoint_dir: the directory to restore the checkpoint from
+        :return: None
+        """
+        checkpoint_dir = os.path.join(checkpoint_dir,
+                                      *(self.full_name_id.split('/')))  # adds both level name and agent name
+        self.input_filter.restore_state_from_checkpoint(checkpoint_dir)
+        self.pre_network_filter.restore_state_from_checkpoint(checkpoint_dir)
+
+        # no output filters currently have an internal state to restore
+        # self.output_filter.restore_state_from_checkpoint(checkpoint_dir)
+
     def sync(self) -> None:
         """

View File

@@ -392,6 +392,9 @@ class CompositeAgent(AgentInterface):
     def save_checkpoint(self, checkpoint_id: int) -> None:
         [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
 
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
+
     def set_incoming_directive(self, action: ActionType) -> None:
         self.incoming_action = action
         if isinstance(self.decision_policy, SingleDecider) and isinstance(self.in_action_space, AgentSelection):

View File

@@ -204,5 +204,6 @@ class NECAgent(ValueOptimizationAgent):
                                                  actions, discounted_rewards)
 
     def save_checkpoint(self, checkpoint_id):
+        super().save_checkpoint(checkpoint_id)
         with open(os.path.join(self.ap.task_parameters.checkpoint_save_dir, str(checkpoint_id) + '.dnd'), 'wb') as f:
             pickle.dump(self.networks['main'].online_network.output_heads[0].DND, f, pickle.HIGHEST_PROTOCOL)

View File

@@ -128,3 +128,11 @@ class TFSharedRunningStats(SharedRunningStats):
             return self.sess.run(self.clipped_obs, feed_dict={self.raw_obs: batch})
         else:
             return self.sess.run(self.normalized_obs, feed_dict={self.raw_obs: batch})
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        # the stats are part of the TF graph - no need to explicitly save anything
+        pass
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        # the stats are part of the TF graph - no need to explicitly restore anything
+        pass

View File

@@ -506,7 +506,7 @@ class AgentParameters(Parameters):
         self.input_filter = None
         self.output_filter = None
         self.pre_network_filter = NoInputFilter()
-        self.full_name_id = None  # TODO: do we really want to hold this parameter here?
+        self.full_name_id = None
         self.name = None
         self.is_a_highest_level_agent = True
         self.is_a_lowest_level_agent = True

View File

@@ -118,7 +118,7 @@ def handle_distributed_coach_tasks(graph_manager, args):
     )
 
-def handle_distributed_coach_orchestrator(graph_manager, args):
+def handle_distributed_coach_orchestrator(args):
     from rl_coach.orchestrators.kubernetes_orchestrator import KubernetesParameters, Kubernetes, \
         RunTypeParameters

View File

@@ -15,6 +15,7 @@
 #
 import copy
+import os
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Union, List
@@ -25,12 +26,13 @@ from rl_coach.utils import force_list
 class Filter(object):
-    def __init__(self):
-        pass
+    def __init__(self, name=None):
+        self.name = name
 
     def reset(self) -> None:
         """
         Called from reset() and implements the reset logic for the filter.
         :return: None
         """
         pass
@@ -64,14 +66,39 @@ class Filter(object):
         """
         pass
 
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id) -> None:
+        """
+        Save the filter's internal state to a checkpoint file, so that it can later be restored.
+        :param checkpoint_dir: the directory in which to save the filter
+        :param checkpoint_id: the checkpoint's ID
+        :return: None
+        """
+        pass
+
+    def restore_state_from_checkpoint(self, checkpoint_dir) -> None:
+        """
+        Restore the filter's internal state from a checkpoint file.
+        :param checkpoint_dir: the directory from which to restore the filter
+        :return: None
+        """
+        pass
+
+    def set_name(self, name: str) -> None:
+        """
+        Set the filter's name.
+        :param name: the filter's name
+        :return: None
+        """
+        self.name = name
+
 class OutputFilter(Filter):
     """
     An output filter is a module that filters the output from an agent to the environment.
     """
     def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None,
-                 is_a_reference_filter: bool=False):
-        super().__init__()
+                 is_a_reference_filter: bool=False, name=None):
+        super().__init__(name)
 
         if action_filters is None:
             action_filters = OrderedDict([])
@@ -194,6 +221,15 @@ class OutputFilter(Filter):
         """
         del self._action_filters[filter_name]
 
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
+        """
+        Currently not in use for OutputFilter.
+        :param checkpoint_dir:
+        :param checkpoint_id:
+        :return:
+        """
+        pass
+
 class NoOutputFilter(OutputFilter):
     """
@@ -209,8 +245,8 @@ class InputFilter(Filter):
     """
     def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None,
                  reward_filters: Dict[str, 'RewardFilter']=None,
-                 is_a_reference_filter: bool=False):
-        super().__init__()
+                 is_a_reference_filter: bool=False, name=None):
+        super().__init__(name)
 
         if observation_filters is None:
             observation_filters = {}
         if reward_filters is None:
@@ -299,7 +335,6 @@ class InputFilter(Filter):
         return filtered_data
 
-
     def get_filtered_observation_space(self, observation_name: str,
                                        input_observation_space: ObservationSpace) -> ObservationSpace:
         """
@@ -409,12 +444,47 @@ class InputFilter(Filter):
         """
         del self._reward_filters[filter_name]
 
+    def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
+        """
+        Save the filter's internal state to a checkpoint file, so that it can later be restored.
+        :param checkpoint_dir: the directory in which to save the filter
+        :param checkpoint_id: the checkpoint's ID
+        :return: None
+        """
+        checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
+        if self.name is not None:
+            checkpoint_dir = os.path.join(checkpoint_dir, self.name)
+
+        for filter_name, filter in self._reward_filters.items():
+            filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name),
+                                            checkpoint_id)
+
+        for observation_name, filters_dict in self._observation_filters.items():
+            for filter_name, filter in filters_dict.items():
+                filter.save_state_to_checkpoint(os.path.join(checkpoint_dir, 'observation_filters',
+                                                             observation_name, filter_name), checkpoint_id)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir) -> None:
+        """
+        Restore the filter's internal state from a checkpoint file.
+        :param checkpoint_dir: the directory from which to restore the filter
+        :return: None
+        """
+        checkpoint_dir = os.path.join(checkpoint_dir, 'filters')
+        if self.name is not None:
+            checkpoint_dir = os.path.join(checkpoint_dir, self.name)
+
+        for filter_name, filter in self._reward_filters.items():
+            filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'reward_filters', filter_name))
+
+        for observation_name, filters_dict in self._observation_filters.items():
+            for filter_name, filter in filters_dict.items():
+                filter.restore_state_from_checkpoint(os.path.join(checkpoint_dir, 'observation_filters',
+                                                                  observation_name, filter_name))
+
 class NoInputFilter(InputFilter):
     """
     Creates an empty input filter. Used only for readability when creating the presets
     """
     def __init__(self):
-        super().__init__(is_a_reference_filter=False)
+        super().__init__(is_a_reference_filter=False, name='no_input_filter')
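Putting the joins above together: for a pre-network filter named 'pre_network_filter' wrapping an ObservationNormalizationFilter registered as 'normalize_observation' under the observation named 'observation' (as in the preset change below), the saved state lands at paths shaped like the following. The '<...>' parts are placeholders, and the .srs extension comes from NumpySharedRunningStats (see the last file in this diff):

    <checkpoint_save_dir>/<level_name>/<agent_name>/filters/pre_network_filter/
        observation_filters/observation/normalize_observation/<checkpoint_id>.srs
    <checkpoint_save_dir>/<level_name>/<agent_name>/filters/pre_network_filter/
        reward_filters/<reward_filter_name>/<checkpoint_id>.srs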

View File

@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
+import pickle
 from typing import List
 
 import numpy as np
@@ -79,3 +81,12 @@ class ObservationNormalizationFilter(ObservationFilter):
         self.running_observation_stats.set_params(shape=input_observation_space.shape,
                                                   clip_values=(self.clip_min, self.clip_max))
         return input_observation_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        self.running_observation_stats.restore_state_from_checkpoint(checkpoint_dir)

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 
 import numpy as np
@@ -74,3 +74,9 @@ class RewardNormalizationFilter(RewardFilter):
     def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
         return input_reward_space
+
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir)
+
+        self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)

View File

@@ -565,7 +565,7 @@ class GraphManager(object):
         self.verify_graph_was_created()
 
         # TODO: find better way to load checkpoints that were saved with a global network into the online network
-        if hasattr(self.task_parameters, 'checkpoint_restore_dir') and self.task_parameters.checkpoint_restore_dir:
+        if self.task_parameters.checkpoint_restore_dir:
             checkpoint = get_checkpoint_state(self.task_parameters.checkpoint_restore_dir)
             screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
@@ -577,6 +577,8 @@ class GraphManager(object):
             else:
                 raise ValueError('Invalid framework {}'.format(self.task_parameters.framework_type))
 
+            [manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
+
     def occasionally_save_checkpoint(self):
         # only the chief process saves checkpoints
         if self.task_parameters.checkpoint_save_secs \

View File

@@ -255,6 +255,13 @@ class LevelManager(EnvironmentInterface):
         """
         [agent.save_checkpoint(checkpoint_id) for agent in self.agents.values()]
 
+    def restore_checkpoint(self, checkpoint_dir: str) -> None:
+        """
+        Restores checkpoints of the networks of all agents
+        :return: None
+        """
+        [agent.restore_checkpoint(checkpoint_dir) for agent in self.agents.values()]
+
     def sync(self) -> None:
         """
         Sync the networks of the agents with the global network parameters

View File

@@ -5,6 +5,7 @@ from rl_coach.core_types import TrainingSteps, EnvironmentEpisodes, EnvironmentS
 from rl_coach.environments.gym_environment import GymVectorEnvironment, mujoco_v2
 from rl_coach.exploration_policies.additive_noise import AdditiveNoiseParameters
 from rl_coach.exploration_policies.e_greedy import EGreedyParameters
+from rl_coach.filters.filter import InputFilter
 from rl_coach.filters.observation.observation_normalization_filter import ObservationNormalizationFilter
 from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
 from rl_coach.graph_managers.graph_manager import ScheduleParameters
@@ -47,6 +48,7 @@ agent_params.algorithm.num_steps_between_copying_online_weights_to_target = Envi
 # Distributed Coach synchronization type.
 agent_params.algorithm.distributed_coach_synchronization_type = DistributedCoachSynchronizationType.SYNC
 
+agent_params.pre_network_filter = InputFilter()
 agent_params.pre_network_filter.add_observation_filter('observation', 'normalize_observation',
                                                        ObservationNormalizationFilter(name='normalize_observation'))

View File

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 from abc import ABC, abstractmethod
 import threading
 import pickle
@@ -102,6 +102,14 @@ class SharedRunningStats(ABC):
     def set_session(self, sess):
         pass
 
+    @abstractmethod
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        pass
+
+    @abstractmethod
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        pass
+
 class NumpySharedRunningStats(SharedRunningStats):
     def __init__(self, name, epsilon=1e-2, pubsub_params=None):
@@ -156,4 +164,21 @@ class NumpySharedRunningStats(SharedRunningStats):
         # no session for the numpy implementation
         pass
 
+    def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
+        with open(os.path.join(checkpoint_dir, str(checkpoint_id) + '.srs'), 'wb') as f:
+            pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)
+
+    def restore_state_from_checkpoint(self, checkpoint_dir: str):
+        latest_checkpoint = -1
+
+        # get all checkpoint files
+        for fname in os.listdir(checkpoint_dir):
+            path = os.path.join(checkpoint_dir, fname)
+            if os.path.isdir(path):
+                continue
+            checkpoint_id = int(fname.split('.')[0])
+            if checkpoint_id > latest_checkpoint:
+                latest_checkpoint = checkpoint_id
+
+        with open(os.path.join(checkpoint_dir, str(latest_checkpoint) + '.srs'), 'rb') as f:
+            temp_running_observation_stats = pickle.load(f)
+            self.__dict__.update(temp_running_observation_stats)
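The persistence mechanism above is plain pickling of the object's __dict__, keyed by checkpoint id; restore scans the directory for the highest id. A self-contained sketch of the same round trip, using a hypothetical Stats class rather than Coach's actual one:

    import os
    import pickle

    class Stats:
        def __init__(self):
            self.mean, self.count = 0.0, 0

        def save(self, ckpt_dir, ckpt_id):
            # serialize the full internal state under '<id>.srs'
            with open(os.path.join(ckpt_dir, str(ckpt_id) + '.srs'), 'wb') as f:
                pickle.dump(self.__dict__, f, pickle.HIGHEST_PROTOCOL)

        def restore(self, ckpt_dir):
            # pick the file with the highest numeric prefix, as the diff does
            ids = [int(name.split('.')[0]) for name in os.listdir(ckpt_dir)
                   if not os.path.isdir(os.path.join(ckpt_dir, name))]
            with open(os.path.join(ckpt_dir, str(max(ids)) + '.srs'), 'rb') as f:
                self.__dict__.update(pickle.load(f))

Note that update replaces every pickled attribute, not just the running statistics, so restore brings back the object's entire saved state.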