mirror of
https://github.com/gryf/coach.git
synced 2026-03-01 14:15:46 +01:00
Save filters' internal state (#127)
* save filters internal state * moving the restore to be made from within NumpyRunningStats
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#
|
||||
|
||||
import copy
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Union, List
|
||||
@@ -25,12 +26,13 @@ from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class Filter(object):
|
||||
def __init__(self, name=None):
    """
    Create a filter, optionally tagged with a name.
    :param name: the filter's name. Stored unchanged on the instance as self.name (defaults to None)
    """
    self.name = name
|
||||
|
||||
def reset(self) -> None:
    """
    Reset the filter. The implementation here is a no-op that takes no arguments
    besides the instance itself.
    :return: None
    """
    pass
|
||||
@@ -64,14 +66,39 @@ class Filter(object):
|
||||
"""
|
||||
pass
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id)->None:
    """
    Persist the filter's internal state as part of a checkpoint so it can be restored later.
    This implementation stores nothing.
    :param checkpoint_dir: the directory in which to save the filter
    :param checkpoint_id: the checkpoint's ID
    :return: None
    """
    return None
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
    """
    Restore the filter's internal state from a previously saved checkpoint.
    This implementation restores nothing.
    :param checkpoint_dir: the directory from which to restore the filter
    :return: None
    """
    pass
|
||||
|
||||
def set_name(self, name: str) -> None:
    """
    Assign a name to the filter, replacing any name it had before.
    :param name: the filter's name
    :return: None
    """
    setattr(self, 'name', name)
|
||||
|
||||
|
||||
class OutputFilter(Filter):
|
||||
"""
|
||||
An output filter is a module that filters the output from an agent to the environment.
|
||||
"""
|
||||
def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None,
|
||||
is_a_reference_filter: bool=False):
|
||||
super().__init__()
|
||||
is_a_reference_filter: bool=False, name=None):
|
||||
super().__init__(name)
|
||||
|
||||
if action_filters is None:
|
||||
action_filters = OrderedDict([])
|
||||
@@ -194,6 +221,15 @@ class OutputFilter(Filter):
|
||||
"""
|
||||
del self._action_filters[filter_name]
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
    """
    Checkpoint hook for the output filter. Currently not in use: nothing is written.
    :param checkpoint_dir: ignored
    :param checkpoint_id: ignored
    :return: None
    """
    return None
|
||||
|
||||
|
||||
class NoOutputFilter(OutputFilter):
|
||||
"""
|
||||
@@ -209,8 +245,8 @@ class InputFilter(Filter):
|
||||
"""
|
||||
def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None,
|
||||
reward_filters: Dict[str, 'RewardFilter']=None,
|
||||
is_a_reference_filter: bool=False):
|
||||
super().__init__()
|
||||
is_a_reference_filter: bool=False, name=None):
|
||||
super().__init__(name)
|
||||
if observation_filters is None:
|
||||
observation_filters = {}
|
||||
if reward_filters is None:
|
||||
@@ -299,7 +335,6 @@ class InputFilter(Filter):
|
||||
|
||||
return filtered_data
|
||||
|
||||
|
||||
def get_filtered_observation_space(self, observation_name: str,
|
||||
input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||
"""
|
||||
@@ -409,12 +444,47 @@ class InputFilter(Filter):
|
||||
"""
|
||||
del self._reward_filters[filter_name]
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir, checkpoint_id):
    """
    Save the internal state of every reward filter and observation filter to a checkpoint,
    so that it can be later restored. Each sub-filter is given its own directory:
        <checkpoint_dir>/filters[/<self.name>]/reward_filters/<filter_name>
        <checkpoint_dir>/filters[/<self.name>]/observation_filters/<observation_name>/<filter_name>
    The <self.name> component is only present when the filter has a name.
    :param checkpoint_dir: the directory in which to save the filter
    :param checkpoint_id: the checkpoint's ID, passed unchanged to every sub-filter
    :return: None
    """
    filters_dir = os.path.join(checkpoint_dir, 'filters')
    if self.name is not None:
        filters_dir = os.path.join(filters_dir, self.name)

    # the per-kind prefixes do not depend on the loop variables, so build them once
    reward_dir = os.path.join(filters_dir, 'reward_filters')
    for filter_name, reward_filter in self._reward_filters.items():
        reward_filter.save_state_to_checkpoint(os.path.join(reward_dir, filter_name), checkpoint_id)

    observations_dir = os.path.join(filters_dir, 'observation_filters')
    for observation_name, filters_dict in self._observation_filters.items():
        for filter_name, observation_filter in filters_dict.items():
            observation_filter.save_state_to_checkpoint(
                os.path.join(observations_dir, observation_name, filter_name), checkpoint_id)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir)->None:
    """
    Restore the internal state of every reward filter and observation filter from a checkpoint.
    Each sub-filter is pointed at its own directory:
        <checkpoint_dir>/filters[/<self.name>]/reward_filters/<filter_name>
        <checkpoint_dir>/filters[/<self.name>]/observation_filters/<observation_name>/<filter_name>
    The <self.name> component is only present when the filter has a name.
    :param checkpoint_dir: the directory from which to restore the filter
    :return: None
    """
    filters_dir = os.path.join(checkpoint_dir, 'filters')
    if self.name is not None:
        filters_dir = os.path.join(filters_dir, self.name)

    # the per-kind prefixes do not depend on the loop variables, so build them once
    reward_dir = os.path.join(filters_dir, 'reward_filters')
    for filter_name, reward_filter in self._reward_filters.items():
        reward_filter.restore_state_from_checkpoint(os.path.join(reward_dir, filter_name))

    observations_dir = os.path.join(filters_dir, 'observation_filters')
    for observation_name, filters_dict in self._observation_filters.items():
        for filter_name, observation_filter in filters_dict.items():
            observation_filter.restore_state_from_checkpoint(
                os.path.join(observations_dir, observation_name, filter_name))
|
||||
|
||||
|
||||
class NoInputFilter(InputFilter):
    """
    Creates an empty input filter. Used only for readability when creating the presets
    """
    def __init__(self):
        # Initialize the base class exactly once, as a non-reference filter carrying a
        # fixed name.
        super().__init__(is_a_reference_filter=False, name='no_input_filter')
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import pickle
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
@@ -79,3 +81,12 @@ class ObservationNormalizationFilter(ObservationFilter):
|
||||
self.running_observation_stats.set_params(shape=input_observation_space.shape,
|
||||
clip_values=(self.clip_min, self.clip_max))
|
||||
return input_observation_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
    """
    Save the running observation statistics to a checkpoint, creating the target
    directory (including missing parents) when it does not exist yet.
    :param checkpoint_dir: the directory in which to save the filter
    :param checkpoint_id: the checkpoint's ID
    :return: None
    """
    # exist_ok avoids the check-then-create race when several workers save concurrently
    os.makedirs(checkpoint_dir, exist_ok=True)

    self.running_observation_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
|
||||
def restore_state_from_checkpoint(self, checkpoint_dir: str):
    """
    Restore the running observation statistics by delegating to the statistics object.
    :param checkpoint_dir: the directory passed on to the statistics object's restore call
    :return: None
    """
    stats = self.running_observation_stats
    stats.restore_state_from_checkpoint(checkpoint_dir)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -74,3 +74,9 @@ class RewardNormalizationFilter(RewardFilter):
|
||||
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
    """
    Return the reward space as seen after this filter. The input space is handed back
    untouched (the very same object).
    :param input_reward_space: the reward space entering the filter
    :return: the same reward space object that was passed in
    """
    filtered_reward_space = input_reward_space
    return filtered_reward_space
|
||||
|
||||
def save_state_to_checkpoint(self, checkpoint_dir: str, checkpoint_id: int):
    """
    Save the running reward statistics to a checkpoint, creating the target
    directory (including missing parents) when it does not exist yet.
    :param checkpoint_dir: the directory in which to save the filter
    :param checkpoint_id: the checkpoint's ID
    :return: None
    """
    # exist_ok avoids the check-then-create race when several workers save concurrently
    os.makedirs(checkpoint_dir, exist_ok=True)

    self.running_rewards_stats.save_state_to_checkpoint(checkpoint_dir, checkpoint_id)
|
||||
Reference in New Issue
Block a user