mirror of
https://github.com/gryf/coach.git
synced 2026-03-12 20:45:55 +01:00
pre-release 0.10.0
This commit is contained in:
70
rl_coach/filters/README.md
Normal file
70
rl_coach/filters/README.md
Normal file
@@ -0,0 +1,70 @@
|
||||
A custom observation filter implementation should look like this:
|
||||
|
||||
```python
|
||||
from coach.filters.filter import ObservationFilter
|
||||
|
||||
class CustomFilter(ObservationFilter):
|
||||
def __init__(self):
|
||||
...
|
||||
def filter(self, env_response: EnvResponse) -> EnvResponse:
|
||||
...
|
||||
def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
|
||||
...
|
||||
def validate_input_observation_space(self, input_observation_space: ObservationSpace):
|
||||
...
|
||||
def reset(self):
|
||||
...
|
||||
```
|
||||
|
||||
or for reward filters:
|
||||
```python
|
||||
from coach.filters.filter import RewardFilter
|
||||
|
||||
class CustomFilter(RewardFilter):
|
||||
def __init__(self):
|
||||
...
|
||||
def filter(self, env_response: EnvResponse) -> EnvResponse:
|
||||
...
|
||||
def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
|
||||
...
|
||||
def reset(self):
|
||||
...
|
||||
```
|
||||
|
||||
To create a stack of filters:
|
||||
|
||||
```python
|
||||
from coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter, RescaleInterpolationType
|
||||
from coach.filters.observation.observation_crop_filter import ObservationCropFilter
|
||||
from coach.filters.reward.reward_clipping_filter import RewardClippingFilter
|
||||
from environments.environment_interface import ObservationSpace
|
||||
import numpy as np
|
||||
from core_types import EnvResponse
|
||||
from filters.filter import InputFilter
|
||||
from collections import OrderedDict
|
||||
|
||||
env_response = EnvResponse({'observation': np.ones([210, 160])}, reward=100, game_over=False)
|
||||
|
||||
rescale = ObservationRescaleToSizeFilter(
|
||||
output_observation_space=ObservationSpace(np.array([110, 84])),
|
||||
rescaling_interpolation_type=RescaleInterpolationType.BILINEAR
|
||||
)
|
||||
|
||||
crop = ObservationCropFilter(
|
||||
crop_low=np.array([16, 0]),
|
||||
crop_high=np.array([100, 84])
|
||||
)
|
||||
|
||||
clip = RewardClippingFilter(
|
||||
clipping_low=-1,
|
||||
clipping_high=1
|
||||
)
|
||||
|
||||
input_filter = InputFilter(
|
||||
observation_filters=OrderedDict([('rescale', rescale), ('crop', crop)]),
|
||||
reward_filters=OrderedDict([('clip', clip)])
|
||||
)
|
||||
|
||||
result = input_filter.filter(env_response)
|
||||
|
||||
```
|
||||
0
rl_coach/filters/__init__.py
Normal file
0
rl_coach/filters/__init__.py
Normal file
0
rl_coach/filters/action/__init__.py
Normal file
0
rl_coach/filters/action/__init__.py
Normal file
69
rl_coach/filters/action/action_filter.py
Normal file
69
rl_coach/filters/action/action_filter.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.spaces import ActionSpace
|
||||
|
||||
from rl_coach.core_types import ActionType
|
||||
from rl_coach.filters.filter import Filter
|
||||
|
||||
|
||||
class ActionFilter(Filter):
    """Base class for filters that map between an agent's action space and the environment's."""

    def __init__(self, input_action_space: ActionSpace=None):
        # agent-side (unfiltered) space; subclasses typically derive it in
        # get_unfiltered_action_space
        self.input_action_space = input_action_space
        # environment-side (filtered) space; populated once the output space is known
        self.output_action_space = None
        super().__init__()

    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> ActionSpace:
        """
        Compute the agent-side action space that corresponds to the given output space.
        The base implementation is the identity mapping.
        :param output_action_space: the output action space
        :return: the unfiltered action space
        """
        return output_action_space

    def validate_output_action_space(self, output_action_space: ActionSpace):
        """
        Validate that the given output action space is usable by this filter.
        The base implementation accepts everything.
        :param output_action_space: the output action space to validate
        :return: None
        """
        pass

    def validate_output_action(self, action: ActionType):
        """
        Verify that the given action lies inside the expected output action space.
        :param action: an action to validate
        :return: None
        """
        if self.output_action_space.val_matches_space_definition(action):
            return
        raise ValueError("The given action ({}) does not match the action space ({})"
                         .format(action, self.output_action_space))

    def filter(self, action: ActionType) -> ActionType:
        """
        Transform an action from the agent's action space to the environment's action space.
        :param action: an action to transform
        :return: transformed action
        """
        raise NotImplementedError("")

    def reverse_filter(self, action: ActionType) -> ActionType:
        """
        Transform an action from the environment's action space to the agent's action space.
        :param action: an action to transform
        :return: transformed action
        """
        raise NotImplementedError("")
|
||||
66
rl_coach/filters/action/attention_discretization.py
Normal file
66
rl_coach/filters/action/attention_discretization.py
Normal file
@@ -0,0 +1,66 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.filters.action.box_discretization import BoxDiscretization
|
||||
|
||||
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
|
||||
from rl_coach.spaces import AttentionActionSpace, BoxActionSpace, DiscreteActionSpace
|
||||
|
||||
|
||||
class AttentionDiscretization(PartialDiscreteActionSpaceMap):
    """
    Given a box action space, this is used to discretize the space.
    The discretization is achieved by creating a grid in the space with num_bins_per_dimension bins per dimension in the
    space. Each discrete action is mapped to a single sub-box in the BoxActionSpace action space.
    """
    def __init__(self, num_bins_per_dimension: Union[int, List[int]], force_int_bins=False):
        # we allow specifying either a single number for all dimensions, or a single number per dimension in the target
        # action space
        self.num_bins_per_dimension = num_bins_per_dimension

        # when True the underlying BoxDiscretization snaps grid lines to integers
        self.force_int_bins = force_int_bins

        # TODO: this will currently only work for attention spaces with 2 dimensions. generalize it.

        super().__init__()

    def validate_output_action_space(self, output_action_space: AttentionActionSpace):
        # only an AttentionActionSpace can be carved into attention sub-boxes
        if not isinstance(output_action_space, AttentionActionSpace):
            raise ValueError("AttentionActionSpace discretization only works with an output space of type AttentionActionSpace. "
                             "The given output space is {}".format(output_action_space))

    def get_unfiltered_action_space(self, output_action_space: AttentionActionSpace) -> DiscreteActionSpace:
        # normalize a single int into one bin count per dimension of the target space
        if isinstance(self.num_bins_per_dimension, int):
            self.num_bins_per_dimension = [self.num_bins_per_dimension] * output_action_space.shape[0]

        # create a discrete to linspace map to ease the extraction of attention actions
        # (n bins require n+1 grid lines per dimension, hence n+1 below)
        discrete_to_box = BoxDiscretization([n+1 for n in self.num_bins_per_dimension],
                                            self.force_int_bins)
        discrete_to_box.get_unfiltered_action_space(BoxActionSpace(output_action_space.shape,
                                                                   output_action_space.low,
                                                                   output_action_space.high), )

        # each target action is a [low corner, high corner] pair describing one grid cell.
        # start_ind enumerates the flattened index of every cell's top-left grid point;
        # end_ind is the matching bottom-right point, one row (cols + 1) plus one column
        # further along the flattened (rows + 1) x (cols + 1) grid.
        rows, cols = self.num_bins_per_dimension
        start_ind = [i * (cols + 1) + j for i in range(rows + 1) if i < rows for j in range(cols + 1) if j < cols]
        end_ind = [i + cols + 2 for i in start_ind]
        self.target_actions = [np.array([discrete_to_box.target_actions[start],
                                         discrete_to_box.target_actions[end]])
                               for start, end in zip(start_ind, end_ind)]

        return super().get_unfiltered_action_space(output_action_space)
|
||||
70
rl_coach/filters/action/box_discretization.py
Normal file
70
rl_coach/filters/action/box_discretization.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from itertools import product
|
||||
from typing import Union, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
|
||||
from rl_coach.spaces import BoxActionSpace, DiscreteActionSpace
|
||||
|
||||
|
||||
class BoxDiscretization(PartialDiscreteActionSpaceMap):
    """
    Given a box action space, this is used to discretize the space.
    The discretization is achieved by creating a grid in the space with num_bins_per_dimension bins per dimension in the
    space. Each discrete action is mapped to a single N dimensional action in the BoxActionSpace action space.
    """
    def __init__(self, num_bins_per_dimension: Union[int, List[int]], force_int_bins=False):
        """
        :param num_bins_per_dimension: The number of bins to use for each dimension of the target action space.
                                       The bins will be spread out uniformly over this space
        :param force_int_bins: force the bins to represent only integer actions. for example, if the action space is in
                               the range 0-10 and there are 5 bins, then the bins will be placed at 0, 2, 5, 7, 10,
                               instead of 0, 2.5, 5, 7.5, 10.
        """
        # we allow specifying either a single number for all dimensions, or a single number per dimension in the target
        # action space
        self.num_bins_per_dimension = num_bins_per_dimension
        self.force_int_bins = force_int_bins
        super().__init__()

    def validate_output_action_space(self, output_action_space: BoxActionSpace):
        """
        Verify the output space is a BoxActionSpace and that the per-dimension bin counts
        match its dimensionality.
        :param output_action_space: the environment-side action space
        :return: None
        """
        if not isinstance(output_action_space, BoxActionSpace):
            raise ValueError("BoxActionSpace discretization only works with an output space of type BoxActionSpace. "
                             "The given output space is {}".format(output_action_space))

        # A single int means "this many bins in every dimension", so there is no
        # per-dimension length to validate. The previous code unconditionally called
        # len(self.num_bins_per_dimension), raising TypeError for the int case that
        # __init__ explicitly allows.
        if not isinstance(self.num_bins_per_dimension, int) and \
                len(self.num_bins_per_dimension) != output_action_space.shape:
            # TODO: this check is not sufficient. it does not deal with actions spaces with more than one axis
            raise ValueError("The length of the list of bins per dimension ({}) does not match the number of "
                             "dimensions in the action space ({})"
                             .format(len(self.num_bins_per_dimension), output_action_space))

    def get_unfiltered_action_space(self, output_action_space: BoxActionSpace) -> DiscreteActionSpace:
        """
        Build the grid of target box actions and the corresponding discrete agent-side space.
        :param output_action_space: the environment-side box action space
        :return: the discrete (agent-side) action space
        """
        if isinstance(self.num_bins_per_dimension, int):
            # expand a single bin count into one count per dimension
            self.num_bins_per_dimension = np.ones(output_action_space.shape) * self.num_bins_per_dimension

        bins = []
        for i in range(len(output_action_space.low)):
            # int() cast: np.linspace requires an integer `num`, while the np.ones
            # expansion above yields floats
            dim_bins = np.linspace(output_action_space.low[i], output_action_space.high[i],
                                   int(self.num_bins_per_dimension[i]))
            if self.force_int_bins:
                dim_bins = dim_bins.astype(int)
            bins.append(dim_bins)
        # the cartesian product of the per-dimension bins forms the grid of target actions
        self.target_actions = [list(action) for action in list(product(*bins))]

        return super().get_unfiltered_action_space(output_action_space)
|
||||
83
rl_coach/filters/action/box_masking.py
Normal file
83
rl_coach/filters/action/box_masking.py
Normal file
@@ -0,0 +1,83 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import BoxActionSpace
|
||||
|
||||
from rl_coach.core_types import ActionType
|
||||
from rl_coach.filters.action.action_filter import ActionFilter
|
||||
|
||||
|
||||
class BoxMasking(ActionFilter):
    """
    Restricts a box action space so that only a sub-box of it can be selected.
    For example,
    - the target action space has actions of shape 1 with values between 10 and 32
    - we mask the target action space so that only the action 20 to 25 can be chosen
    The agent then acts in the range 0 to 5 and the mapping adds an offset of 20 to the incoming actions.
    The shape of the source and target action spaces is always the same.
    """
    def __init__(self,
                 masked_target_space_low: Union[None, int, float, np.ndarray],
                 masked_target_space_high: Union[None, int, float, np.ndarray]):
        """
        :param masked_target_space_low: the lowest values that can be chosen in the target action space
        :param masked_target_space_high: the highest values that can be chosen in the target action space
        """
        self.masked_target_space_low = masked_target_space_low
        self.masked_target_space_high = masked_target_space_high
        # agent actions start at 0, so the low bound doubles as the shift applied in filter()
        self.offset = masked_target_space_low
        super().__init__()

    def set_masking(self, masked_target_space_low: Union[None, int, float, np.ndarray],
                    masked_target_space_high: Union[None, int, float, np.ndarray]):
        # re-apply the mask after construction, re-deriving the agent-side space if the
        # output space is already known
        self.masked_target_space_low = masked_target_space_low
        self.masked_target_space_high = masked_target_space_high
        self.offset = masked_target_space_low
        if self.output_action_space:
            self.validate_output_action_space(self.output_action_space)
            self.input_action_space = BoxActionSpace(self.output_action_space.shape,
                                                     low=0,
                                                     high=self.masked_target_space_high - self.masked_target_space_low)

    def validate_output_action_space(self, output_action_space: BoxActionSpace):
        if not isinstance(output_action_space, BoxActionSpace):
            raise ValueError("BoxActionSpace discretization only works with an output space of type BoxActionSpace. "
                             "The given output space is {}".format(output_action_space))
        if self.masked_target_space_low is None or self.masked_target_space_high is None:
            raise ValueError("The masking target space size was not set. Please call set_masking.")

        def within_bounds(values):
            # True when every element of values lies inside the target space's range
            return np.all(output_action_space.low <= values) and np.all(values <= output_action_space.high)

        if not within_bounds(self.masked_target_space_low):
            raise ValueError("The low values for masking the action space ({}) are not within the range of the "
                             "target space (low = {}, high = {})"
                             .format(self.masked_target_space_low, output_action_space.low, output_action_space.high))
        if not within_bounds(self.masked_target_space_high):
            raise ValueError("The high values for masking the action space ({}) are not within the range of the "
                             "target space (low = {}, high = {})"
                             .format(self.masked_target_space_high, output_action_space.low, output_action_space.high))

    def get_unfiltered_action_space(self, output_action_space: BoxActionSpace) -> BoxActionSpace:
        self.output_action_space = output_action_space
        self.input_action_space = BoxActionSpace(output_action_space.shape,
                                                 low=0,
                                                 high=self.masked_target_space_high - self.masked_target_space_low)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
        # shift the zero-based agent action back into the masked target range
        return action + self.offset
|
||||
32
rl_coach/filters/action/full_discrete_action_space_map.py
Normal file
32
rl_coach/filters/action/full_discrete_action_space_map.py
Normal file
@@ -0,0 +1,32 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.filters.action.partial_discrete_action_space_map import PartialDiscreteActionSpaceMap
|
||||
from rl_coach.spaces import ActionSpace, DiscreteActionSpace
|
||||
|
||||
|
||||
class FullDiscreteActionSpaceMap(PartialDiscreteActionSpaceMap):
    """
    Maps every action of the output space onto a discrete action in the input space.
    For example, an output space holding 10 multiselect actions becomes a discrete space
    where the actions 0-9 select those multiselect actions.
    """
    def __init__(self):
        super().__init__()

    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
        # use every action the environment exposes as a discretization target
        self.target_actions = output_action_space.actions
        return super().get_unfiltered_action_space(output_action_space)
|
||||
60
rl_coach/filters/action/linear_box_to_box_map.py
Normal file
60
rl_coach/filters/action/linear_box_to_box_map.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import BoxActionSpace
|
||||
|
||||
from rl_coach.core_types import ActionType
|
||||
from rl_coach.filters.action.action_filter import ActionFilter
|
||||
|
||||
|
||||
class LinearBoxToBoxMap(ActionFilter):
    """
    Linearly maps one box action space onto another.
    For example,
    - the source action space has actions of shape 1 with values between -42 and -10,
    - the target action space has actions of shape 1 with values between 10 and 32
    Incoming actions are shifted by an offset of 52 and scaled by 22/32 onto the target range.
    The shape of the source and target action spaces is always the same.
    """
    def __init__(self,
                 input_space_low: Union[None, int, float, np.ndarray],
                 input_space_high: Union[None, int, float, np.ndarray]):
        self.input_space_low = input_space_low
        self.input_space_high = input_space_high
        # computed once the output space is known (get_unfiltered_action_space)
        self.rescale = None
        self.offset = None
        super().__init__()

    def validate_output_action_space(self, output_action_space: BoxActionSpace):
        if isinstance(output_action_space, BoxActionSpace):
            return
        raise ValueError("BoxActionSpace discretization only works with an output space of type BoxActionSpace. "
                         "The given output space is {}".format(output_action_space))

    def get_unfiltered_action_space(self, output_action_space: BoxActionSpace) -> BoxActionSpace:
        input_range = self.input_space_high - self.input_space_low
        output_range = output_action_space.high - output_action_space.low
        self.rescale = output_range / input_range
        self.offset = output_action_space.low - self.input_space_low
        self.output_action_space = output_action_space
        self.input_action_space = BoxActionSpace(output_action_space.shape, self.input_space_low, self.input_space_high)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
        # translate to the input range origin, scale, then translate to the target origin
        shifted = action - self.input_space_low
        return self.output_action_space.low + shifted * self.rescale
|
||||
|
||||
54
rl_coach/filters/action/partial_discrete_action_space_map.py
Normal file
54
rl_coach/filters/action/partial_discrete_action_space_map.py
Normal file
@@ -0,0 +1,54 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List
|
||||
|
||||
from rl_coach.spaces import DiscreteActionSpace, ActionSpace
|
||||
|
||||
from rl_coach.core_types import ActionType
|
||||
from rl_coach.filters.action.action_filter import ActionFilter
|
||||
|
||||
|
||||
class PartialDiscreteActionSpaceMap(ActionFilter):
    """
    Maps a chosen subset of the output space's actions onto discrete agent-side actions.
    For example, 10 multiselect actions in the output space become the discrete actions
    0-9 on the agent side.
    """
    def __init__(self, target_actions: List[ActionType]=None, descriptions: List[str]=None):
        self.target_actions = target_actions
        self.descriptions = descriptions
        super().__init__()

    def validate_output_action_space(self, output_action_space: ActionSpace):
        if not self.target_actions:
            raise ValueError("The target actions were not set")
        for candidate in self.target_actions:
            if output_action_space.val_matches_space_definition(candidate):
                continue
            raise ValueError("The values in the output actions ({}) do not match the output action "
                             "space definition ({})".format(candidate, output_action_space))

    def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> DiscreteActionSpace:
        self.output_action_space = output_action_space
        self.input_action_space = DiscreteActionSpace(len(self.target_actions), self.descriptions)
        return self.input_action_space

    def filter(self, action: ActionType) -> ActionType:
        # a discrete agent action is simply an index into the mapped target actions
        return self.target_actions[action]

    def reverse_filter(self, action: ActionType) -> ActionType:
        # find the target action equal to the given one and return its index
        matches = [(action == candidate).all() for candidate in self.target_actions]
        return matches.index(True)
|
||||
|
||||
418
rl_coach/filters/filter.py
Normal file
418
rl_coach/filters/filter.py
Normal file
@@ -0,0 +1,418 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Dict, Union, List
|
||||
|
||||
from rl_coach.spaces import ActionSpace, RewardSpace, ObservationSpace
|
||||
from rl_coach.core_types import EnvResponse, ActionInfo, Transition
|
||||
from rl_coach.utils import force_list
|
||||
|
||||
|
||||
class Filter(object):
    """Common base class for all input/output filters."""

    def __init__(self):
        pass

    def reset(self) -> None:
        """
        Implements the per-episode reset logic for the filter.
        Called from reset() of the owning filter stack.
        :return: None
        """
        pass

    def filter(self, env_response: Union[EnvResponse, Transition], update_internal_state: bool=True) \
            -> Union[EnvResponse, Transition]:
        """
        Filter some values of the given env response and return the filtered result.
        Subclasses implement their filtering logic by overriding this method.
        :param update_internal_state: should the filter's internal state change due to this call
        :param env_response: the input env_response
        :return: the filtered env_response
        """
        raise NotImplementedError("")

    def set_device(self, device) -> None:
        """
        Optional hook handing the filter a device, for filters using tensorflow ops.
        :param device: the device to use
        :return: None
        """
        pass

    def set_session(self, sess) -> None:
        """
        Optional hook handing the filter a session, for filters using tensorflow ops.
        :param sess: the session
        :return: None
        """
        pass
|
||||
|
||||
|
||||
class OutputFilter(Filter):
|
||||
"""
|
||||
An output filter is a module that filters the output from an agent to the environment.
|
||||
"""
|
||||
def __init__(self, action_filters: OrderedDict([(str, 'ActionFilter')])=None,
|
||||
is_a_reference_filter: bool=False):
|
||||
super().__init__()
|
||||
|
||||
if action_filters is None:
|
||||
action_filters = OrderedDict([])
|
||||
self._action_filters = action_filters
|
||||
|
||||
# We do not want to allow reference filters such as Atari to be used directly. These have to be duplicated first
|
||||
# and only then can change their values so to keep their original params intact for other agents in the graph.
|
||||
self.i_am_a_reference_filter = is_a_reference_filter
|
||||
|
||||
def __call__(self):
|
||||
duplicate = deepcopy(self)
|
||||
duplicate.i_am_a_reference_filter = False
|
||||
return duplicate
|
||||
|
||||
def set_device(self, device) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the device if it is required to use tensorflow ops
|
||||
:param device: the device to use
|
||||
:return: None
|
||||
"""
|
||||
[f.set_device(device) for f in self.action_filters.values()]
|
||||
|
||||
def set_session(self, sess) -> None:
|
||||
"""
|
||||
An optional function that allows the filter to get the session if it is required to use tensorflow ops
|
||||
:param sess: the session
|
||||
:return: None
|
||||
"""
|
||||
[f.set_session(sess) for f in self.action_filters.values()]
|
||||
|
||||
def filter(self, action_info: ActionInfo) -> ActionInfo:
|
||||
"""
|
||||
A wrapper around _filter which first copies the action_info so that we don't change the original one
|
||||
This function should not be updated!
|
||||
:param action_info: the input action_info
|
||||
:return: the filtered action_info
|
||||
"""
|
||||
if self.i_am_a_reference_filter:
|
||||
raise Exception("The filter being used is a reference filter. It is not to be used directly. "
|
||||
"Instead get a duplicate from it by calling __call__.")
|
||||
if len(self.action_filters.values()) == 0:
|
||||
return action_info
|
||||
filtered_action_info = copy.deepcopy(action_info)
|
||||
filtered_action = filtered_action_info.action
|
||||
for filter in reversed(self.action_filters.values()):
|
||||
filtered_action = filter.filter(filtered_action)
|
||||
|
||||
filtered_action_info.action = filtered_action
|
||||
|
||||
return filtered_action_info
|
||||
|
||||
def reverse_filter(self, action_info: ActionInfo) -> ActionInfo:
|
||||
"""
|
||||
A wrapper around _reverse_filter which first copies the action_info so that we don't change the original one
|
||||
This function should not be updated!
|
||||
:param action_info: the input action_info
|
||||
:return: the filtered action_info
|
||||
"""
|
||||
if self.i_am_a_reference_filter:
|
||||
raise Exception("The filter being used is a reference filter. It is not to be used directly. "
|
||||
"Instead get a duplicate from it by calling __call__.")
|
||||
filtered_action_info = copy.deepcopy(action_info)
|
||||
filtered_action = filtered_action_info.action
|
||||
for filter in self.action_filters.values():
|
||||
filter.validate_output_action(filtered_action)
|
||||
filtered_action = filter.reverse_filter(filtered_action)
|
||||
|
||||
filtered_action_info.action = filtered_action
|
||||
|
||||
return filtered_action_info
|
||||
|
||||
def get_unfiltered_action_space(self, output_action_space: ActionSpace) -> ActionSpace:
|
||||
"""
|
||||
Given the output action space, returns the corresponding unfiltered action space
|
||||
This function should not be updated!
|
||||
:param output_action_space: the output action space
|
||||
:return: the unfiltered action space
|
||||
"""
|
||||
unfiltered_action_space = copy.deepcopy(output_action_space)
|
||||
for filter in self._action_filters.values():
|
||||
new_unfiltered_action_space = filter.get_unfiltered_action_space(unfiltered_action_space)
|
||||
filter.validate_output_action_space(unfiltered_action_space)
|
||||
unfiltered_action_space = new_unfiltered_action_space
|
||||
return unfiltered_action_space
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
Reset any internal memory stored in the filter.
|
||||
This function should not be updated!
|
||||
This is useful for stateful filters which stores information on previous filter calls.
|
||||
:return: None
|
||||
"""
|
||||
[action_filter.reset() for action_filter in self._action_filters.values()]
|
||||
|
||||
@property
def action_filters(self) -> Dict[str, 'ActionFilter']:
    # ordered mapping of filter name -> ActionFilter; filters are applied in insertion order
    return self._action_filters

@action_filters.setter
def action_filters(self, val: Dict[str, 'ActionFilter']):
    # expected to be an OrderedDict so that add_action_filter's move_to_end works
    self._action_filters = val
|
||||
|
||||
def add_action_filter(self, filter_name: str, filter: 'ActionFilter', add_as_the_first_filter: bool=False):
    """
    Add an action filter to the stack of action filters.
    :param filter_name: the filter name
    :param filter: the filter to add
    :param add_as_the_first_filter: add the filter to the top of the filters stack
    :return: None
    """
    filters = self._action_filters
    filters[filter_name] = filter
    if add_as_the_first_filter:
        # OrderedDict.move_to_end with last=False pushes the new entry to the front of the stack
        filters.move_to_end(filter_name, last=False)
|
||||
|
||||
def remove_action_filter(self, filter_name: str) -> None:
    """
    Remove an action filter from the stack of action filters.
    :param filter_name: the filter name
    :return: None
    """
    # pop raises KeyError for an unknown name, exactly like `del`
    self._action_filters.pop(filter_name)
|
||||
|
||||
|
||||
class NoOutputFilter(OutputFilter):
    """
    Creates an empty output filter. Used only for readability when creating the presets
    """
    def __init__(self):
        # not a reference filter, so it can be used directly without duplicating via __call__
        super().__init__(is_a_reference_filter=False)
|
||||
|
||||
|
||||
class InputFilter(Filter):
    """
    An input filter is a module that filters the input from an environment to the agent.
    It holds a stack of observation filters per observation name and a stack of reward filters,
    which are applied in order to every EnvResponse / Transition passed through filter().
    """
    def __init__(self, observation_filters: Dict[str, Dict[str, 'ObservationFilter']]=None,
                 reward_filters: Dict[str, 'RewardFilter']=None,
                 is_a_reference_filter: bool=False):
        """
        :param observation_filters: a mapping from observation name to an ordered mapping of
                                    filter name -> observation filter
        :param reward_filters: an ordered mapping of filter name -> reward filter
        :param is_a_reference_filter: if True, this filter is only a template and must be duplicated
                                      (by calling it) before it can be used
        """
        super().__init__()
        if observation_filters is None:
            observation_filters = {}
        if reward_filters is None:
            reward_filters = OrderedDict([])
        self._observation_filters = observation_filters
        self._reward_filters = reward_filters

        # We do not want to allow reference filters such as Atari to be used directly. These have to be duplicated first
        # and only then can change their values so to keep their original params intact for other agents in the graph.
        self.i_am_a_reference_filter = is_a_reference_filter

    def __call__(self):
        # return a usable (non-reference) duplicate of this filter.
        # use copy.deepcopy for consistency with the rest of this class (the bare `deepcopy` name
        # is not guaranteed to be imported).
        duplicate = copy.deepcopy(self)
        duplicate.i_am_a_reference_filter = False
        return duplicate

    def set_device(self, device) -> None:
        """
        An optional function that allows the filter to get the device if it is required to use tensorflow ops
        :param device: the device to use
        :return: None
        """
        for f in self.reward_filters.values():
            f.set_device(device)
        for filters in self.observation_filters.values():
            for f in filters.values():
                f.set_device(device)

    def set_session(self, sess) -> None:
        """
        An optional function that allows the filter to get the session if it is required to use tensorflow ops
        :param sess: the session
        :return: None
        """
        for f in self.reward_filters.values():
            f.set_session(sess)
        for filters in self.observation_filters.values():
            for f in filters.values():
                f.set_session(sess)

    def filter(self, unfiltered_data: Union[EnvResponse, List[EnvResponse], Transition, List[Transition]],
               update_internal_state: bool=True, deep_copy: bool=True) -> Union[List[EnvResponse], List[Transition]]:
        """
        A wrapper around the filtering logic which first copies the input data so that we don't change
        the original one.
        This function should not be updated!
        :param unfiltered_data: the input data
        :param update_internal_state: should the filter's internal state change due to this call
        :param deep_copy: if True, deep copy the input data before filtering; otherwise only shallow copy
                          each element (in which case unfiltered_data must be iterable)
        :return: the filtered data, always as a list
        """
        if self.i_am_a_reference_filter:
            raise Exception("The filter being used is a reference filter. It is not to be used directly. "
                            "Instead get a duplicate from it by calling __call__.")
        if deep_copy:
            filtered_data = copy.deepcopy(unfiltered_data)
        else:
            filtered_data = [copy.copy(t) for t in unfiltered_data]
        filtered_data = force_list(filtered_data)

        # nothing to filter - avoid indexing into an empty list below
        if not filtered_data:
            return filtered_data

        # TODO: implement observation space validation
        # filter observations
        if isinstance(filtered_data[0], Transition):
            # transitions carry two states to filter - the current state and the next state
            state_objects_to_filter = [[f.state for f in filtered_data],
                                       [f.next_state for f in filtered_data]]
        elif isinstance(filtered_data[0], EnvResponse):
            state_objects_to_filter = [[f.next_state for f in filtered_data]]
        else:
            raise ValueError("unfiltered_data should be either of type EnvResponse or Transition. ")

        for state_object_list in state_objects_to_filter:
            for observation_name, filters in self._observation_filters.items():
                if observation_name in state_object_list[0].keys():
                    for filter in filters.values():
                        data_to_filter = [state_object[observation_name] for state_object in state_object_list]
                        if filter.supports_batching:
                            # batching filters process all the observations in one call
                            filtered_observations = filter.filter(
                                data_to_filter, update_internal_state=update_internal_state)
                        else:
                            filtered_observations = []
                            for data_point in data_to_filter:
                                filtered_observations.append(filter.filter(
                                    data_point, update_internal_state=update_internal_state))

                        # write the filtered observations back into the state objects
                        for i, state_object in enumerate(state_object_list):
                            state_object[observation_name] = filtered_observations[i]

        # filter reward
        for f in filtered_data:
            filtered_reward = f.reward
            for filter in self._reward_filters.values():
                filtered_reward = filter.filter(filtered_reward, update_internal_state)
            f.reward = filtered_reward

        return filtered_data

    def get_filtered_observation_space(self, observation_name: str,
                                       input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Given the input observation space, returns the corresponding filtered observation space
        This function should not be updated!
        :param observation_name: the name of the observation to which we want to calculate the filtered space
        :param input_observation_space: the input observation space
        :return: the filtered observation space
        """
        filtered_observation_space = copy.deepcopy(input_observation_space)
        if observation_name in self._observation_filters.keys():
            for filter in self._observation_filters[observation_name].values():
                filter.validate_input_observation_space(filtered_observation_space)
                filtered_observation_space = filter.get_filtered_observation_space(filtered_observation_space)
        return filtered_observation_space

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        """
        Given the input reward space, returns the corresponding filtered reward space
        This function should not be updated!
        :param input_reward_space: the input reward space
        :return: the filtered reward space
        """
        filtered_reward_space = copy.deepcopy(input_reward_space)
        for filter in self._reward_filters.values():
            filtered_reward_space = filter.get_filtered_reward_space(filtered_reward_space)
        return filtered_reward_space

    def reset(self) -> None:
        """
        Reset any internal memory stored in the filter.
        This function should not be updated!
        This is useful for stateful filters which store information on previous filter calls.
        :return: None
        """
        for curr_observation_filters in self._observation_filters.values():
            for observation_filter in curr_observation_filters.values():
                observation_filter.reset()
        for reward_filter in self._reward_filters.values():
            reward_filter.reset()

    @property
    def observation_filters(self) -> Dict[str, Dict[str, 'ObservationFilter']]:
        return self._observation_filters

    @observation_filters.setter
    def observation_filters(self, val: Dict[str, Dict[str, 'ObservationFilter']]):
        self._observation_filters = val

    @property
    def reward_filters(self) -> Dict[str, 'RewardFilter']:
        # ordered mapping of filter name -> RewardFilter; filters are applied in insertion order
        return self._reward_filters

    @reward_filters.setter
    def reward_filters(self, val: Dict[str, 'RewardFilter']):
        # expected to be an OrderedDict so that add_reward_filter's move_to_end works
        self._reward_filters = val

    def copy_filters_from_one_observation_to_another(self, from_observation: str, to_observation: str):
        """
        Copy all the filters created for some observation to another observation
        :param from_observation: the source observation to copy from
        :param to_observation: the target observation to copy to
        :return: None
        """
        self._observation_filters[to_observation] = copy.deepcopy(self._observation_filters[from_observation])

    def add_observation_filter(self, observation_name: str, filter_name: str, filter: 'ObservationFilter',
                               add_as_the_first_filter: bool=False):
        """
        Add an observation filter to the filters list
        :param observation_name: the name of the observation to apply to
        :param filter_name: the filter name
        :param filter: the filter to add
        :param add_as_the_first_filter: add the filter to the top of the filters stack
        :return: None
        """
        if observation_name not in self._observation_filters.keys():
            self._observation_filters[observation_name] = OrderedDict()
        self._observation_filters[observation_name][filter_name] = filter
        if add_as_the_first_filter:
            self._observation_filters[observation_name].move_to_end(filter_name, last=False)

    def add_reward_filter(self, filter_name: str, filter: 'RewardFilter', add_as_the_first_filter: bool=False):
        """
        Add a reward filter to the filters list
        :param filter_name: the filter name
        :param filter: the filter to add
        :param add_as_the_first_filter: add the filter to the top of the filters stack
        :return: None
        """
        self._reward_filters[filter_name] = filter
        if add_as_the_first_filter:
            self._reward_filters.move_to_end(filter_name, last=False)

    def remove_observation_filter(self, observation_name: str, filter_name: str) -> None:
        """
        Remove an observation filter from the filters list
        :param observation_name: the name of the observation to apply to
        :param filter_name: the filter name
        :return: None
        """
        del self._observation_filters[observation_name][filter_name]

    def remove_reward_filter(self, filter_name: str) -> None:
        """
        Remove a reward filter from the filters list
        :param filter_name: the filter name
        :return: None
        """
        del self._reward_filters[filter_name]
|
||||
|
||||
|
||||
class NoInputFilter(InputFilter):
    """
    Creates an empty input filter. Used only for readability when creating the presets
    """
    def __init__(self):
        # not a reference filter, so it can be used directly without duplicating via __call__
        super().__init__(is_a_reference_filter=False)
|
||||
|
||||
|
||||
0
rl_coach/filters/observation/__init__.py
Normal file
0
rl_coach/filters/observation/__init__.py
Normal file
44
rl_coach/filters/observation/observation_clipping_filter.py
Normal file
44
rl_coach/filters/observation/observation_clipping_filter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationClippingFilter(ObservationFilter):
    """
    Clip the observation values using the given ranges
    """
    def __init__(self, clipping_low: float=-np.inf, clipping_high: float=np.inf):
        """
        :param clipping_low: The minimum value to allow after clipping the observation
        :param clipping_high: The maximum value to allow after clipping the observation
        """
        super().__init__()
        self.clip_min = clipping_low
        self.clip_max = clipping_high

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        # np.clip saturates every value to the [clip_min, clip_max] range; scalars and arrays alike
        observation = np.clip(observation, self.clip_min, self.clip_max)

        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        # clipping does not change the observation shape
        return input_observation_space
|
||||
92
rl_coach/filters/observation/observation_crop_filter.py
Normal file
92
rl_coach/filters/observation/observation_crop_filter.py
Normal file
@@ -0,0 +1,92 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationCropFilter(ObservationFilter):
    """
    Crops the current state observation to a given shape
    """
    def __init__(self, crop_low: np.ndarray=None, crop_high: np.ndarray=None):
        """
        :param crop_low: a vector where each dimension describes the start index for cropping the observation in the
                         corresponding dimension. a negative value of -1 will be mapped to the max size
        :param crop_high: a vector where each dimension describes the end index for cropping the observation in the
                          corresponding dimension. a negative value of -1 will be mapped to the max size
        """
        super().__init__()
        if crop_low is None and crop_high is None:
            raise ValueError("At least one of crop_low and crop_high should be set to a real value. ")
        # default the missing side: start at 0, or end at the full size (-1)
        if crop_low is None:
            crop_low = np.array([0] * len(crop_high))
        if crop_high is None:
            crop_high = np.array([-1] * len(crop_low))

        self.crop_low = crop_low
        self.crop_high = crop_high

        # validate the crop vectors; h == -1 means "up to the max size" and is always valid
        for h, l in zip(crop_high, crop_low):
            if h < l and h != -1:
                raise ValueError("Some of the cropping low values are higher than cropping high values")
        if np.any(crop_high < -1) or np.any(crop_low < -1):
            raise ValueError("Cropping values cannot be negative")
        if crop_low.shape != crop_high.shape:
            raise ValueError("The low values and high values for cropping must have the same number of dimensions")
        if crop_low.dtype != int or crop_high.dtype != int:
            raise ValueError("The crop values should be int values, instead they are defined as: {} and {}"
                             .format(crop_low.dtype, crop_high.dtype))

    def _replace_negative_one_in_crop_size(self, crop_size: np.ndarray, observation_shape: Union[Tuple, np.ndarray]):
        """
        Return a copy of crop_size where every -1 entry is replaced by the matching observation dimension size.
        :param crop_size: the crop vector, possibly containing -1 entries
        :param observation_shape: the shape of the observation being cropped
        :return: a copy of crop_size with all -1 entries resolved
        """
        # replace -1 with the max size
        crop_size = crop_size.copy()
        for i in range(len(observation_shape)):
            if crop_size[i] == -1:
                crop_size[i] = observation_shape[i]
        return crop_size

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        crop_high = self._replace_negative_one_in_crop_size(self.crop_high, input_observation_space.shape)
        crop_low = self._replace_negative_one_in_crop_size(self.crop_low, input_observation_space.shape)
        if np.any(crop_high > input_observation_space.shape) or \
                np.any(crop_low > input_observation_space.shape):
            raise ValueError("The cropping values are outside of the observation space")
        # crop_high is exclusive, so the last index actually read is crop_high - 1
        if not input_observation_space.is_point_in_space_shape(crop_low) or \
                not input_observation_space.is_point_in_space_shape(crop_high - 1):
            raise ValueError("The cropping indices are outside of the observation space")

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        # replace -1 with the max size
        crop_high = self._replace_negative_one_in_crop_size(self.crop_high, observation.shape)
        crop_low = self._replace_negative_one_in_crop_size(self.crop_low, observation.shape)

        # crop. the index must be a *tuple* of slices: indexing a numpy array with a list of slices
        # was deprecated in numpy 1.15 and raises an error in recent numpy versions.
        indices = tuple(slice(i, j) for i, j in zip(crop_low, crop_high))
        observation = observation[indices]
        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        # replace -1 with the max size
        crop_high = self._replace_negative_one_in_crop_size(self.crop_high, input_observation_space.shape)
        crop_low = self._replace_negative_one_in_crop_size(self.crop_low, input_observation_space.shape)

        # the filtered shape is the size of the cropped region
        input_observation_space.shape = crop_high - crop_low
        return input_observation_space
|
||||
40
rl_coach/filters/observation/observation_filter.py
Normal file
40
rl_coach/filters/observation/observation_filter.py
Normal file
@@ -0,0 +1,40 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.filters.filter import Filter
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
|
||||
class ObservationFilter(Filter):
    """
    Base class for filters that transform a single observation.
    """
    def __init__(self):
        super().__init__()
        # set to True in subclasses whose filter() accepts a whole batch (list) of observations in one call
        self.supports_batching = False

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        This function should contain the logic for getting the filtered observation space
        :param input_observation_space: the input observation space
        :return: the filtered observation space
        """
        # default: the filter does not change the observation space
        return input_observation_space

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        A function that implements validation of the input observation space
        :param input_observation_space: the input observation space
        :return: None
        """
        # default: accept any observation space
        pass
|
||||
62
rl_coach/filters/observation/observation_move_axis_filter.py
Normal file
62
rl_coach/filters/observation/observation_move_axis_filter.py
Normal file
@@ -0,0 +1,62 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationMoveAxisFilter(ObservationFilter):
    """
    Move an axis of the observation to a different place.
    """
    def __init__(self, axis_origin: int = None, axis_target: int=None):
        """
        :param axis_origin: the index of the axis to move (may be negative)
        :param axis_target: the index to move the axis to (may be negative)
        """
        super().__init__()
        self.axis_origin = axis_origin
        self.axis_target = axis_target

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        shape = input_observation_space.shape
        # both axes must be valid indices for the observation rank, allowing negative indexing
        if not -len(shape) <= self.axis_origin < len(shape) or not -len(shape) <= self.axis_target < len(shape):
            raise ValueError("The given axis does not exist in the context of the input observation shape. ")

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        return np.moveaxis(observation, self.axis_origin, self.axis_target)

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        # recompute the shape by removing the origin axis and re-inserting its size at the target position
        axis_size = input_observation_space.shape[self.axis_origin]
        input_observation_space.shape = np.delete(input_observation_space.shape, self.axis_origin)
        if self.axis_target == -1:
            input_observation_space.shape = np.append(input_observation_space.shape, axis_size)
        elif self.axis_target < -1:
            # after the delete the array is one shorter, so the negative target shifts by one
            input_observation_space.shape = np.insert(input_observation_space.shape, self.axis_target+1, axis_size)
        else:
            input_observation_space.shape = np.insert(input_observation_space.shape, self.axis_target, axis_size)

        # move the channels axis according to the axis change
        # NOTE(review): the comparisons below assume channels_axis, axis_origin and axis_target use the same
        # sign convention (all non-negative or all negative) - mixed signs would not be matched; confirm callers
        if isinstance(input_observation_space, PlanarMapsObservationSpace):
            if input_observation_space.channels_axis == self.axis_origin:
                input_observation_space.channels_axis = self.axis_target
            elif input_observation_space.channels_axis == self.axis_target:
                input_observation_space.channels_axis = self.axis_origin
            elif self.axis_origin < input_observation_space.channels_axis < self.axis_target:
                # the moved axis passed over the channels axis from left to right
                input_observation_space.channels_axis -= 1
            elif self.axis_target < input_observation_space.channels_axis < self.axis_origin:
                # the moved axis passed over the channels axis from right to left
                input_observation_space.channels_axis += 1

        return input_observation_space
|
||||
@@ -0,0 +1,73 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationNormalizationFilter(ObservationFilter):
    """
    Normalize the observation with a running standard deviation and mean of the observations seen so far
    If there is more than a single worker, the statistics of the observations are shared between all the workers
    """
    def __init__(self, clip_min: float=-5.0, clip_max: float=5.0, name='observation_stats'):
        """
        :param clip_min: The minimum value to allow after normalizing the observation
        :param clip_max: The maximum value to allow after normalizing the observation
        :param name: the name of the shared running-stats variables (used to share them between workers)
        """
        super().__init__()
        self.clip_min = clip_min
        self.clip_max = clip_max
        # created lazily in set_device, since SharedRunningStats needs the device
        self.running_observation_stats = None
        self.name = name
        # this filter processes a whole batch of observations in one filter() call
        self.supports_batching = True
        self.observation_space = None

    def set_device(self, device) -> None:
        """
        An optional function that allows the filter to get the device if it is required to use tensorflow ops
        :param device: the device to use
        :return: None
        """
        # ops are created later, in get_filtered_observation_space, once the observation shape is known
        self.running_observation_stats = SharedRunningStats(device, name=self.name, create_ops=False)

    def set_session(self, sess) -> None:
        """
        An optional function that allows the filter to get the session if it is required to use tensorflow ops
        :param sess: the session
        :return: None
        """
        self.running_observation_stats.set_session(sess)

    def filter(self, observations: List[ObservationType], update_internal_state: bool=True) -> ObservationType:
        observations = np.array(observations)
        if update_internal_state:
            # fold the new batch into the running statistics and snapshot them for inspection
            self.running_observation_stats.push(observations)
            self.last_mean = self.running_observation_stats.mean
            self.last_stdev = self.running_observation_stats.std

        # TODO: make sure that a batch is given here
        return self.running_observation_stats.normalize(observations)

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        # the normalization ops can only be created here, where the observation shape is known;
        # normalization does not change the observation space itself
        self.running_observation_stats.create_ops(shape=input_observation_space.shape,
                                                  clip_values=(self.clip_min, self.clip_max))
        return input_observation_space
|
||||
@@ -0,0 +1,76 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from rl_coach.spaces import ObservationSpace, VectorObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationReductionBySubPartsNameFilter(ObservationFilter):
    """
    Choose sub parts of the observation to remove or keep using their name.
    This is useful when the environment has a measurements vector as observation which includes several different
    measurements, but you want the agent to only see some of the measurements and not all.
    This will currently work only for VectorObservationSpace observations
    """
    class ReductionMethod(Enum):
        # keep only the named parts
        Keep = 0
        # remove the named parts and keep the rest
        Discard = 1

    def __init__(self, part_names: List[str], reduction_method: ReductionMethod):
        """
        :param part_names: A list of part names to reduce
        :param reduction_method: A reduction method to use - keep or discard the given parts
        """
        super().__init__()
        self.part_names = part_names
        self.reduction_method = reduction_method
        # both are resolved in get_filtered_observation_space, once the measurement names are known
        self.measurement_names = None
        self.indices_to_keep = None

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        if self.indices_to_keep is None:
            raise ValueError("To use ObservationReductionBySubPartsNameFilter, the get_filtered_observation_space "
                             "function should be called before filtering an observation")
        # select the surviving measurements along the last axis
        observation = observation[..., self.indices_to_keep]
        return observation

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        if not isinstance(input_observation_space, VectorObservationSpace):
            raise ValueError("The ObservationReductionBySubPartsNameFilter support only VectorObservationSpace "
                             "observations. The given observation space was: {}"
                             .format(input_observation_space.__class__))

    def get_filtered_observation_space(self, input_observation_space: VectorObservationSpace) -> ObservationSpace:
        # snapshot the original names before mutating the space
        self.measurement_names = copy.copy(input_observation_space.measurements_names)

        if self.reduction_method == self.ReductionMethod.Keep:
            input_observation_space.shape[-1] = len(self.part_names)
            self.indices_to_keep = [idx for idx, val in enumerate(self.measurement_names) if val in self.part_names]
            input_observation_space.measurements_names = copy.copy(self.part_names)
        elif self.reduction_method == self.ReductionMethod.Discard:
            input_observation_space.shape[-1] -= len(self.part_names)
            self.indices_to_keep = [idx for idx, val in enumerate(self.measurement_names) if val not in self.part_names]
            input_observation_space.measurements_names = [val for val in input_observation_space.measurements_names if
                                                          val not in self.part_names]
        else:
            raise ValueError("The given reduction method is not supported")

        return input_observation_space
|
||||
@@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import scipy.ndimage
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
# imresize interpolation types as defined by scipy here:
|
||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
||||
class RescaleInterpolationType(Enum):
    """Interpolation methods available for rescaling observations.

    The values mirror the ``interp`` strings historically accepted by
    ``scipy.misc.imresize``.
    """
    NEAREST = 'nearest'
    LANCZOS = 'lanczos'
    BILINEAR = 'bilinear'
    BICUBIC = 'bicubic'
    CUBIC = 'cubic'
|
||||
|
||||
|
||||
class ObservationRescaleSizeByFactorFilter(ObservationFilter):
    """
    Scales the current state observation size by a given factor.
    The observation is converted to uint8 before rescaling, and for 3D
    observations the channels axis (assumed to be the last axis) is left
    unchanged.
    """
    def __init__(self, rescale_factor: float, rescaling_interpolation_type: RescaleInterpolationType):
        """
        :param rescale_factor: the factor by which the observation will be rescaled
        :param rescaling_interpolation_type: the interpolation type for rescaling
        """
        super().__init__()
        self.rescale_factor = float(rescale_factor)  # rescaling requires float scale factors
        self.rescaling_interpolation_type = rescaling_interpolation_type
        # TODO: allow selecting the channels dim

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the input is a 2D (grayscale) or 3D (RGB) image observation.
        :param input_observation_space: the observation space to validate
        :raises ValueError: if the space is not a 2D image or a 3-channel 3D image
        """
        if not 2 <= input_observation_space.num_dimensions <= 3:
            raise ValueError("The rescale filter only applies to image observations where the number of dimensions is"
                             "either 2 (grayscale) or 3 (RGB). The number of dimensions defined for the "
                             "output observation was {}".format(input_observation_space.num_dimensions))
        if input_observation_space.num_dimensions == 3 and input_observation_space.shape[-1] != 3:
            raise ValueError("Observations with 3 dimensions must have 3 channels in the last axis (RGB)")

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Rescale the spatial axes of the observation by the rescale factor.
        :param observation: the observation to rescale
        :param update_internal_state: unused - this filter is stateless
        :return: the rescaled observation (uint8)
        """
        # keep the historical behavior of operating on uint8 images
        observation = observation.astype('uint8')

        # zoom both spatial axes by the rescale factor; for 3D observations
        # keep the (last) channels axis unchanged
        zoom_factors = [self.rescale_factor, self.rescale_factor]
        if observation.ndim == 3:
            zoom_factors.append(1)

        # scipy.misc.imresize was never imported here and has been removed from
        # scipy; scipy.ndimage.zoom is the maintained equivalent. lanczos is not
        # supported by scipy.ndimage, so it is approximated with cubic (order 3).
        spline_order = {'nearest': 0, 'bilinear': 1, 'bicubic': 3,
                        'cubic': 3, 'lanczos': 3}[self.rescaling_interpolation_type.value]
        observation = scipy.ndimage.zoom(observation, zoom_factors, order=spline_order)

        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Scale the spatial dimensions of the observation space by the rescale factor.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        input_observation_space.shape[:2] = (input_observation_space.shape[:2] * self.rescale_factor).astype('int')
        return input_observation_space
|
||||
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
import scipy.ndimage
|
||||
from rl_coach.spaces import ObservationSpace, PlanarMapsObservationSpace, ImageObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
# imresize interpolation types as defined by scipy here:
|
||||
# https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.misc.imresize.html
|
||||
class RescaleInterpolationType(Enum):
    """Interpolation methods available for rescaling observations.

    The values mirror the ``interp`` strings historically accepted by
    ``scipy.misc.imresize``.
    """
    NEAREST = 'nearest'
    LANCZOS = 'lanczos'
    BILINEAR = 'bilinear'
    BICUBIC = 'bicubic'
    CUBIC = 'cubic'
|
||||
|
||||
|
||||
class ObservationRescaleToSizeFilter(ObservationFilter):
    """
    Scales the current state observation to a given shape.
    The observation is converted to uint8 before rescaling. The number of
    channels is never changed - only the planar (spatial) axes are rescaled.
    """
    def __init__(self, output_observation_space: PlanarMapsObservationSpace,
                 rescaling_interpolation_type: RescaleInterpolationType=RescaleInterpolationType.BILINEAR):
        """
        :param output_observation_space: the output observation space
        :param rescaling_interpolation_type: the interpolation type for rescaling
        :raises ValueError: if the output space is not a PlanarMapsObservationSpace
        """
        super().__init__()
        self.output_observation_space = output_observation_space
        self.rescaling_interpolation_type = rescaling_interpolation_type

        if not isinstance(output_observation_space, PlanarMapsObservationSpace):
            raise ValueError("The rescale filter only applies to observation spaces that inherit from "
                             "PlanarMapsObservationSpace. This includes observations which consist of a set of 2D "
                             "images or an RGB image. Instead the output observation space was defined as: {}"
                             .format(output_observation_space.__class__))

        # the target spatial shape, i.e. the output shape without the channels axis
        self.planar_map_output_shape = copy.copy(self.output_observation_space.shape)
        self.planar_map_output_shape = np.delete(self.planar_map_output_shape,
                                                 self.output_observation_space.channels_axis)

    def _spline_order(self) -> int:
        """Map the configured interpolation type to a scipy.ndimage spline order.
        lanczos is not supported by scipy.ndimage and is approximated with cubic."""
        return {'nearest': 0, 'bilinear': 1, 'bicubic': 3,
                'cubic': 3, 'lanczos': 3}[self.rescaling_interpolation_type.value]

    def _rescale(self, image, target_shape):
        """Rescale a single image to target_shape using scipy.ndimage.zoom.
        scipy.misc.imresize was never imported here and has been removed from
        scipy; zoom with per-axis factors is the maintained equivalent."""
        zoom_factors = [float(t) / s for t, s in zip(target_shape, image.shape)]
        return scipy.ndimage.zoom(image, zoom_factors, order=self._spline_order())

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the input space is planar and that the channel counts match.
        :param input_observation_space: the observation space to validate
        :raises ValueError: if the space type or its number of channels mismatch
        """
        if not isinstance(input_observation_space, PlanarMapsObservationSpace):
            raise ValueError("The rescale filter only applies to observation spaces that inherit from "
                             "PlanarMapsObservationSpace. This includes observations which consist of a set of 2D "
                             "images or an RGB image. Instead the input observation space was defined as: {}"
                             .format(input_observation_space.__class__))
        if input_observation_space.shape[input_observation_space.channels_axis] \
                != self.output_observation_space.shape[self.output_observation_space.channels_axis]:
            raise ValueError("The number of channels between the input and output observation spaces must match. "
                             "Instead the number of channels were: {}, {}"
                             .format(input_observation_space.shape[input_observation_space.channels_axis],
                                     self.output_observation_space.shape[self.output_observation_space.channels_axis]))

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Rescale the observation to the configured output shape.
        :param observation: the observation to rescale
        :param update_internal_state: unused - this filter is stateless
        :return: the rescaled observation (uint8)
        """
        # keep the historical behavior of operating on uint8 images
        observation = observation.astype('uint8')

        if isinstance(self.output_observation_space, ImageObservationSpace):
            # an RGB image - rescale all axes at once (the channels factor is 1)
            observation = self._rescale(observation, tuple(self.output_observation_space.shape))
        else:
            # a set of 2D maps - rescale each planar map separately and restack
            channels_axis = self.output_observation_space.channels_axis
            num_channels = self.output_observation_space.shape[channels_axis]
            rescaled_maps = [self._rescale(observation.take(i, channels_axis),
                                           tuple(self.planar_map_output_shape))
                             for i in range(num_channels)]
            observation = np.array(rescaled_maps).swapaxes(0, channels_axis)

        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Replace the observation space shape with the configured output shape.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        input_observation_space.shape = self.output_observation_space.shape
        return input_observation_space
|
||||
50
rl_coach/filters/observation/observation_rgb_to_y_filter.py
Normal file
50
rl_coach/filters/observation/observation_rgb_to_y_filter.py
Normal file
@@ -0,0 +1,50 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationRGBToYFilter(ObservationFilter):
    """
    Converts the observation in the current state to gray scale (the Y channel
    of the YUV color space). The channels axis is assumed to be the last axis.
    """
    def __init__(self):
        super().__init__()

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the input is a 3D observation with 3 (RGB) channels in its last axis.
        :param input_observation_space: the observation space to validate
        :raises ValueError: if the observation is not a 3-channel RGB image
        """
        # fixed copy-pasted message: this is the RGB-to-Y filter, and the
        # channels are checked in the last axis, not the 1st
        if input_observation_space.num_dimensions != 3:
            raise ValueError("The RGB to Y filter only applies to image observations where the number of dimensions is"
                             " 3 (RGB). The number of dimensions defined for the input observation was {}"
                             .format(input_observation_space.num_dimensions))
        if input_observation_space.shape[-1] != 3:
            raise ValueError("The observation space is expected to have 3 channels in the last dimension. The number "
                             "of channels received is {}".format(input_observation_space.shape[-1]))

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Convert the RGB observation to its luma (Y) component.
        :param observation: the RGB observation, channels last
        :param update_internal_state: unused - this filter is stateless
        :return: a 2D grayscale observation
        """
        # ITU-R BT.601 luma weights
        r, g, b = observation[:, :, 0], observation[:, :, 1], observation[:, :, 2]
        observation = 0.2989 * r + 0.5870 * g + 0.1140 * b

        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Drop the channels axis from the observation space shape.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        input_observation_space.shape = input_observation_space.shape[:-1]
        return input_observation_space
|
||||
46
rl_coach/filters/observation/observation_squeeze_filter.py
Normal file
46
rl_coach/filters/observation/observation_squeeze_filter.py
Normal file
@@ -0,0 +1,46 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationSqueezeFilter(ObservationFilter):
    """
    Removes redundant (size 1) axes from the observation.
    """
    def __init__(self, axis: int = None):
        """
        :param axis: the axis to squeeze out; when None, all size-1 axes are removed
        """
        super().__init__()
        self.axis = axis

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the configured axis exists in the input observation shape.
        :param input_observation_space: the observation space to validate
        :raises ValueError: if the axis is out of range for the observation shape
        """
        if self.axis is None:
            return

        num_dims = len(input_observation_space.shape)
        # accept negative indices down to -num_dims, like numpy does
        if not -num_dims <= self.axis < num_dims:
            raise ValueError("The given axis does not exist in the context of the input observation shape. ")

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """Return the observation with the configured axis (or all size-1 axes) squeezed out."""
        return np.squeeze(observation, axis=self.axis)

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Compute the squeezed shape by probing a dummy array of the input shape.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        probe = np.empty(tuple(input_observation_space.shape))
        input_observation_space.shape = probe.squeeze(axis=self.axis).shape
        return input_observation_space
|
||||
105
rl_coach/filters/observation/observation_stacking_filter.py
Normal file
105
rl_coach/filters/observation/observation_stacking_filter.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class LazyStack(object):
    """
    A lazy equivalent of np.stack: keeps a shallow copy of the observation
    history and only materializes the stacked array when something converts it
    to a numpy array (e.g. np.array / np.asarray).
    """

    def __init__(self, history, axis=None):
        # shallow copy so later mutations of the caller's container do not
        # change what this stack will produce
        self.history = copy.copy(history)
        self.axis = axis

    def __array__(self, dtype=None):
        """Materialize the stack as a numpy array, optionally cast to dtype."""
        stacked = np.stack(self.history, axis=self.axis)
        return stacked if dtype is None else stacked.astype(dtype)
|
||||
|
||||
|
||||
class ObservationStackingFilter(ObservationFilter):
    """
    Stacks the current state observation on top of several previous observations,
    adding one dimension to the output observation. The filter is stateful: it
    keeps the last stack_size observations between calls.

    Warning!!! The filter replaces the observation with a LazyStack object, so no filters should be
    applied after this filter. applying more filters will cause the LazyStack object to be converted to a numpy array
    and increase the memory footprint.
    """
    def __init__(self, stack_size: int, stacking_axis: int=-1):
        """
        :param stack_size: the number of previous observations in the stack
        :param stacking_axis: the axis on which to stack the observation on
        :raises ValueError: if stack_size is not a positive int
        """
        super().__init__()
        self.stack_size = stack_size
        self.stacking_axis = stacking_axis
        self.stack = []

        if stack_size <= 0:
            raise ValueError("The stack shape must be a positive number")
        # deliberately an exact type check (not isinstance) so bools are rejected too
        if type(stack_size) != int:
            raise ValueError("The stack shape must be of int type")

    @property
    def next_filter(self) -> 'InputFilter':
        return self._next_filter

    @next_filter.setter
    def next_filter(self, val: 'InputFilter'):
        # chaining anything after this filter would force the LazyStack to
        # materialize, defeating its purpose
        raise ValueError("ObservationStackingFilter can have no other filters after it since they break its "
                         "functionality")

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the space matches the stored observations and the stacking axis.
        :param input_observation_space: the observation space to validate
        :raises ValueError: on mismatch with stored observations or an out-of-range axis
        """
        if len(self.stack) > 0 and not input_observation_space.val_matches_space_definition(self.stack[-1]):
            raise ValueError("The given input observation space is different than the observations already stored in"
                             "the filters memory")
        if input_observation_space.num_dimensions <= self.stacking_axis:
            raise ValueError("The stacking axis is larger than the number of dimensions in the observation space")

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Push the observation into the stack and return a lazy view of the stack.
        :param observation: the current observation
        :param update_internal_state: when False, an already-initialized stack is left unchanged
        :return: a LazyStack over the last stack_size observations
        """
        if not self.stack:
            # first observation seen - fill the whole stack with copies of it
            self.stack = deque([observation] * self.stack_size, maxlen=self.stack_size)
        elif update_internal_state:
            self.stack.append(observation)

        return LazyStack(self.stack, self.stacking_axis)

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Insert the stack dimension into the observation space shape.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        if self.stacking_axis == -1:
            input_observation_space.shape = np.append(input_observation_space.shape,
                                                      values=[self.stack_size], axis=0)
        else:
            input_observation_space.shape = np.insert(input_observation_space.shape,
                                                      obj=self.stacking_axis,
                                                      values=[self.stack_size], axis=0)
        return input_observation_space

    def reset(self) -> None:
        """Drop all stored observations."""
        self.stack = []
|
||||
60
rl_coach/filters/observation/observation_to_uint8_filter.py
Normal file
60
rl_coach/filters/observation/observation_to_uint8_filter.py
Normal file
@@ -0,0 +1,60 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import ObservationSpace
|
||||
|
||||
from rl_coach.core_types import ObservationType
|
||||
from rl_coach.filters.observation.observation_filter import ObservationFilter
|
||||
|
||||
|
||||
class ObservationToUInt8Filter(ObservationFilter):
    """
    Converts the observation values to uint8 values between 0 and 255.
    It first scales the observation values from [input_low, input_high] to the
    0-255 range and then casts them to uint8.
    """
    def __init__(self, input_low: float, input_high: float):
        """
        :param input_low: the lowest value in the input observation
        :param input_high: the highest value in the input observation
        :raises ValueError: if input_high is not strictly greater than input_low
        """
        super().__init__()
        self.input_low = input_low
        self.input_high = input_high

        if input_high <= input_low:
            # fixed message: the original stated the opposite of the condition
            raise ValueError("The input observation space high values can not be less than or equal to the input "
                             "observation space low values")

    def validate_input_observation_space(self, input_observation_space: ObservationSpace):
        """
        Validate that the observation space bounds match the configured bounds.
        :param input_observation_space: the observation space to validate
        :raises ValueError: if any element of the space bounds differs from the configuration
        """
        # np.any (not np.all): a mismatch in even a single element of the
        # bounds must be rejected. np.all only raised when *every* element
        # differed, letting partially mismatched spaces through.
        if np.any(input_observation_space.low != self.input_low) or \
                np.any(input_observation_space.high != self.input_high):
            raise ValueError("The observation space values range don't match the configuration of the filter."
                             "The configuration is: low = {}, high = {}. The actual values are: low = {}, high = {}"
                             .format(self.input_low, self.input_high,
                                     input_observation_space.low, input_observation_space.high))

    def filter(self, observation: ObservationType, update_internal_state: bool=True) -> ObservationType:
        """
        Rescale the observation to [0, 255] and cast it to uint8.
        :param observation: the observation to convert
        :param update_internal_state: unused - this filter is stateless
        :return: the converted uint8 observation
        """
        # scale to 0-1
        observation = (observation - self.input_low) / (self.input_high - self.input_low)

        # scale to 0-255
        observation *= 255

        observation = observation.astype('uint8')

        return observation

    def get_filtered_observation_space(self, input_observation_space: ObservationSpace) -> ObservationSpace:
        """
        Set the observation space bounds to the uint8 range.
        :param input_observation_space: the input observation space (mutated in place)
        :return: the updated observation space
        """
        input_observation_space.low = 0
        input_observation_space.high = 255
        return input_observation_space
|
||||
0
rl_coach/filters/reward/__init__.py
Normal file
0
rl_coach/filters/reward/__init__.py
Normal file
53
rl_coach/filters/reward/reward_clipping_filter.py
Normal file
53
rl_coach/filters/reward/reward_clipping_filter.py
Normal file
@@ -0,0 +1,53 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import RewardSpace
|
||||
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
|
||||
|
||||
class RewardClippingFilter(RewardFilter):
    """
    Clips the reward to a given range.
    """
    def __init__(self, clipping_low: float=-np.inf, clipping_high: float=np.inf):
        """
        :param clipping_low: The low threshold for reward clipping
        :param clipping_high: The high threshold for reward clipping
        :raises ValueError: if clipping_low is greater than clipping_high
        """
        super().__init__()
        self.clipping_low = clipping_low
        self.clipping_high = clipping_high

        if clipping_low > clipping_high:
            raise ValueError("The reward clipping low must be lower than the reward clipping max")

    def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
        """
        Clip the reward into [clipping_low, clipping_high].
        :param reward: the reward to clip
        :param update_internal_state: unused - this filter is stateless
        :return: the clipped reward as a float
        """
        reward = float(reward)

        # explicit None checks: the original truthiness tests (`if
        # self.clipping_high:`) silently skipped clipping whenever the
        # threshold was 0, which is a perfectly valid clipping value
        if self.clipping_high is not None:
            reward = min(reward, self.clipping_high)
        if self.clipping_low is not None:
            reward = max(reward, self.clipping_low)

        return reward

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        """
        Tighten the reward space bounds to the clipping range.
        :param input_reward_space: the input reward space (mutated in place)
        :return: the updated reward space
        """
        input_reward_space.high = min(self.clipping_high, input_reward_space.high)
        input_reward_space.low = max(self.clipping_low, input_reward_space.low)
        return input_reward_space
|
||||
31
rl_coach/filters/reward/reward_filter.py
Normal file
31
rl_coach/filters/reward/reward_filter.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.filters.filter import Filter
|
||||
from rl_coach.spaces import RewardSpace
|
||||
|
||||
|
||||
class RewardFilter(Filter):
    """Base class for filters that process the reward of an environment response."""
    def __init__(self):
        super().__init__()

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        """
        This function should contain the logic for getting the filtered reward space
        :param input_reward_space: the input reward space
        :return: the filtered reward space
        """
        # default: the reward space is passed through unchanged; subclasses
        # override this when they change the reward range
        return input_reward_space
|
||||
68
rl_coach/filters/reward/reward_normalization_filter.py
Normal file
68
rl_coach/filters/reward/reward_normalization_filter.py
Normal file
@@ -0,0 +1,68 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
import numpy as np
|
||||
from rl_coach.spaces import RewardSpace
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.shared_variables import SharedRunningStats
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
|
||||
|
||||
class RewardNormalizationFilter(RewardFilter):
    """
    Normalizes the reward using a running mean and standard deviation of all
    the rewards seen so far. When more than a single worker is used, the reward
    statistics are shared between all the workers.
    """
    def __init__(self, clip_min: float=-5.0, clip_max: float=5.0):
        """
        :param clip_min: The minimum value to allow after normalizing the reward
        :param clip_max: The maximum value to allow after normalizing the reward
        """
        super().__init__()
        self.clip_min = clip_min
        self.clip_max = clip_max
        # created lazily in set_device, since it needs a tensorflow device
        self.running_rewards_stats = None

    def set_device(self, device) -> None:
        """
        An optional function that allows the filter to get the device if it is required to use tensorflow ops
        :param device: the device to use
        :return: None
        """
        self.running_rewards_stats = SharedRunningStats(device, name='rewards_stats')

    def set_session(self, sess) -> None:
        """
        An optional function that allows the filter to get the session if it is required to use tensorflow ops
        :param sess: the session
        :return: None
        """
        self.running_rewards_stats.set_session(sess)

    def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
        """
        Normalize the reward by the running statistics and clip the result.
        :param reward: the reward to normalize
        :param update_internal_state: whether to fold this reward into the running statistics
        :return: the normalized and clipped reward
        """
        if update_internal_state:
            self.running_rewards_stats.push(reward)

        # the epsilon guards against division by zero while the std is still 0
        normalized_reward = (reward - self.running_rewards_stats.mean) / \
                            (self.running_rewards_stats.std + 1e-15)
        return np.clip(normalized_reward, self.clip_min, self.clip_max)

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        """Normalization does not change the declared reward space."""
        return input_reward_space
|
||||
44
rl_coach/filters/reward/reward_rescale_filter.py
Normal file
44
rl_coach/filters/reward/reward_rescale_filter.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from rl_coach.spaces import RewardSpace
|
||||
|
||||
from rl_coach.core_types import RewardType
|
||||
from rl_coach.filters.reward.reward_filter import RewardFilter
|
||||
|
||||
|
||||
class RewardRescaleFilter(RewardFilter):
    """
    Rescales the reward by multiplying it with some factor.
    """
    def __init__(self, rescale_factor: float):
        """
        :param rescale_factor: The reward rescaling factor by which the reward will be multiplied
        :raises ValueError: if the rescale factor is 0
        """
        super().__init__()
        self.rescale_factor = rescale_factor

        if rescale_factor == 0:
            raise ValueError("The reward rescale value can not be set to 0")

    def filter(self, reward: RewardType, update_internal_state: bool=True) -> RewardType:
        """
        Multiply the reward by the rescale factor.
        :param reward: the reward to rescale
        :param update_internal_state: unused - this filter is stateless
        :return: the rescaled reward as a float
        """
        reward = float(reward) * self.rescale_factor
        return reward

    def get_filtered_reward_space(self, input_reward_space: RewardSpace) -> RewardSpace:
        """
        Rescale the reward space bounds by the rescale factor.
        :param input_reward_space: the input reward space (mutated in place)
        :return: the updated reward space
        """
        # a negative rescale factor flips the bounds, so reassign min/max
        # explicitly to keep low <= high (the original left high < low)
        rescaled_bounds = (input_reward_space.low * self.rescale_factor,
                           input_reward_space.high * self.rescale_factor)
        input_reward_space.low = min(rescaled_bounds)
        input_reward_space.high = max(rescaled_bounds)
        return input_reward_space
|
||||
Reference in New Issue
Block a user