From 67a90ee87e948296b3204192cf0d0065acdb421c Mon Sep 17 00:00:00 2001 From: Sina Afrooze Date: Mon, 19 Nov 2018 06:41:12 -0800 Subject: [PATCH] Add tensor input type for arbitrary dimensional observation (#125) * Allow arbitrary dimensional observation (non vector or image) * Added creating PlanarMapsObservationSpace to GymEnvironment when number of channels is not 1 or 3 --- rl_coach/architectures/embedder_parameters.py | 4 +- .../mxnet_components/embedders/__init__.py | 5 +- .../embedders/tensor_embedder.py | 53 +++++++++++++++ .../mxnet_components/general_network.py | 12 ++-- .../embedders/__init__.py | 3 +- .../embedders/tensor_embedder.py | 52 +++++++++++++++ .../tensorflow_components/general_network.py | 8 ++- rl_coach/core_types.py | 4 ++ rl_coach/environments/gym_environment.py | 65 +++++++++++++++---- rl_coach/spaces.py | 12 +++- 10 files changed, 194 insertions(+), 24 deletions(-) create mode 100644 rl_coach/architectures/mxnet_components/embedders/tensor_embedder.py create mode 100644 rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py diff --git a/rl_coach/architectures/embedder_parameters.py b/rl_coach/architectures/embedder_parameters.py index 2973a3a..679cac9 100644 --- a/rl_coach/architectures/embedder_parameters.py +++ b/rl_coach/architectures/embedder_parameters.py @@ -30,9 +30,9 @@ class InputEmbedderParameters(NetworkComponentParameters): self.dropout_rate = dropout_rate if input_rescaling is None: - input_rescaling = {'image': 255.0, 'vector': 1.0} + input_rescaling = {'image': 255.0, 'vector': 1.0, 'tensor': 1.0} if input_offset is None: - input_offset = {'image': 0.0, 'vector': 0.0} + input_offset = {'image': 0.0, 'vector': 0.0, 'tensor': 0.0} self.input_rescaling = input_rescaling self.input_offset = input_offset diff --git a/rl_coach/architectures/mxnet_components/embedders/__init__.py b/rl_coach/architectures/mxnet_components/embedders/__init__.py index eb0482f..93d79f2 100644 --- a/rl_coach/architectures/mxnet_components/embedders/__init__.py +++ b/rl_coach/architectures/mxnet_components/embedders/__init__.py @@ -1,4 +1,7 @@ from .image_embedder import ImageEmbedder +from .tensor_embedder import TensorEmbedder from .vector_embedder import VectorEmbedder -__all__ = ['ImageEmbedder', 'VectorEmbedder'] +__all__ = ['ImageEmbedder', + 'TensorEmbedder', + 'VectorEmbedder'] diff --git a/rl_coach/architectures/mxnet_components/embedders/tensor_embedder.py b/rl_coach/architectures/mxnet_components/embedders/tensor_embedder.py new file mode 100644 index 0000000..11235ec --- /dev/null +++ b/rl_coach/architectures/mxnet_components/embedders/tensor_embedder.py @@ -0,0 +1,53 @@ +from typing import Union +from types import ModuleType + +import mxnet as mx +from rl_coach.architectures.embedder_parameters import InputEmbedderParameters +from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder + +nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol] + + +class TensorEmbedder(InputEmbedder): + def __init__(self, params: InputEmbedderParameters): + """ + A tensor embedder is an input embedder that takes a tensor with arbitrary dimension and produces a vector + embedding by passing it through a neural network. An example is video data or 3D image data (i.e. 4D tensors) + or other type of data that is more than 1 dimension (i.e. not vector) but is not an image. + + NOTE: There are no pre-defined schemes for tensor embedder. User must define a custom scheme by passing + a callable object as InputEmbedderParameters.scheme when defining the respective preset. This callable + object must return a Gluon HybridBlock. The hybrid_forward() of this block must accept a single input, + normalized observation, and return an embedding vector for each sample in the batch. + Keep in mind that the scheme is a list of blocks, which are stacked by optional batchnorm, + activation, and dropout in between as specified in InputEmbedderParameters. + + :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function + and dropout properties. + """ + super(TensorEmbedder, self).__init__(params) + self.input_rescaling = params.input_rescaling['tensor'] + self.input_offset = params.input_offset['tensor'] + + @property + def schemes(self) -> dict: + """ + Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used + to create Block when InputEmbedder is initialised. + + Note: Tensor embedder doesn't define any pre-defined scheme. User must provide custom scheme in preset. + + :return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block. + For tensor embedder, this is an empty dictionary. + """ + return {} + + def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type: + """ + Used for forward pass through embedder network. + + :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized). + :param x: image representing environment state, of shape (batch_size, in_channels, height, width). + :return: embedding of environment state, of shape (batch_size, channels). + """ + return super(TensorEmbedder, self).hybrid_forward(F, x, *args, **kwargs) diff --git a/rl_coach/architectures/mxnet_components/general_network.py b/rl_coach/architectures/mxnet_components/general_network.py index 99645fe..8a856f3 100644 --- a/rl_coach/architectures/mxnet_components/general_network.py +++ b/rl_coach/architectures/mxnet_components/general_network.py @@ -33,12 +33,12 @@ from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadPara from rl_coach.architectures.middleware_parameters import MiddlewareParameters from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture -from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder +from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware from rl_coach.architectures.mxnet_components import utils from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType -from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace +from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace class GeneralMxnetNetwork(MxnetArchitecture): @@ -172,7 +172,9 @@ def _get_input_embedder(spaces: SpacesDefinition, .format(input_name, allowed_inputs.keys())) type = "vector" - if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): + if isinstance(allowed_inputs[input_name], TensorObservationSpace): + type = "tensor" + elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): type = "image" def sanitize_params(params: InputEmbedderParameters): @@ -187,8 +189,10 @@ def _get_input_embedder(spaces: SpacesDefinition, module = VectorEmbedder(embedder_params) elif type == 'image': module = ImageEmbedder(embedder_params) + elif type == 'tensor': + module = TensorEmbedder(embedder_params) else: - raise KeyError('Unsupported embedder type: {}'.format(type)) + raise KeyError('Unsupported embedder type: {}'.format(type)) return module diff --git a/rl_coach/architectures/tensorflow_components/embedders/__init__.py b/rl_coach/architectures/tensorflow_components/embedders/__init__.py index eb0482f..5091f35 100644 --- a/rl_coach/architectures/tensorflow_components/embedders/__init__.py +++ b/rl_coach/architectures/tensorflow_components/embedders/__init__.py @@ -1,4 +1,5 @@ from .image_embedder import ImageEmbedder from .vector_embedder import VectorEmbedder +from .tensor_embedder import TensorEmbedder -__all__ = ['ImageEmbedder', 'VectorEmbedder'] +__all__ = ['ImageEmbedder', 'VectorEmbedder', 'TensorEmbedder'] diff --git a/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py b/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py new file mode 100644 index 0000000..286442c --- /dev/null +++ b/rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2017 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import List + +import tensorflow as tf + +from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense +from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder +from rl_coach.base_parameters import EmbedderScheme +from rl_coach.core_types import InputTensorEmbedding + + +class TensorEmbedder(InputEmbedder): + """ + A tensor embedder is an input embedder that takes a tensor with arbitrary dimension and produces a vector + embedding by passing it through a neural network. An example is video data or 3D image data (i.e. 4D tensors) + or other type of data that is more than 1 dimension (i.e. not vector) but is not an image. + + NOTE: There are no pre-defined schemes for tensor embedder. User must define a custom scheme by passing + a callable object as InputEmbedderParameters.scheme when defining the respective preset. This callable + object must accept a single input, the normalized observation, and return a Tensorflow symbol which + will calculate an embedding vector for each sample in the batch. + Keep in mind that the scheme is a list of Tensorflow symbols, which are stacked by optional batchnorm, + activation, and dropout in between as specified in InputEmbedderParameters. + """ + + def __init__(self, input_size: List[int], activation_function=tf.nn.relu, + scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0, + name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None, + dense_layer=Dense, is_training=False): + super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling, + input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training) + self.return_type = InputTensorEmbedding + assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder" + + @property + def schemes(self): + return {} diff --git a/rl_coach/architectures/tensorflow_components/general_network.py b/rl_coach/architectures/tensorflow_components/general_network.py index b7856aa..28b9c60 100644 --- a/rl_coach/architectures/tensorflow_components/general_network.py +++ b/rl_coach/architectures/tensorflow_components/general_network.py @@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.architecture import TensorFlow from rl_coach.architectures.tensorflow_components import utils from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType from rl_coach.core_types import PredictionType -from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace +from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string @@ -116,10 +116,12 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture): raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}" .format(input_name, allowed_inputs.keys())) - mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder'} + mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'TensorEmbedder'} emb_type = "vector" - if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): + if isinstance(allowed_inputs[input_name], TensorObservationSpace): + emb_type = "tensor" + elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace): emb_type = "image" embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type] diff --git a/rl_coach/core_types.py b/rl_coach/core_types.py index 5f5b4f1..d615b48 100644 --- a/rl_coach/core_types.py +++ b/rl_coach/core_types.py @@ -114,6 +114,10 @@ class InputVectorEmbedding(InputEmbedding): pass +class InputTensorEmbedding(InputEmbedding): + pass + + class Middleware_FC_Embedding(MiddlewareEmbedding): pass diff --git a/rl_coach/environments/gym_environment.py b/rl_coach/environments/gym_environment.py index 43ddaef..adf096d 100644 --- a/rl_coach/environments/gym_environment.py +++ b/rl_coach/environments/gym_environment.py @@ -16,6 +16,7 @@ import gym import numpy as np +from enum import IntEnum import scipy.ndimage from rl_coach.graph_managers.graph_manager import ScheduleParameters @@ -44,7 +45,7 @@ from typing import Dict, Any, Union from rl_coach.core_types import RunPhase, EnvironmentSteps from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \ - StateSpace, RewardSpace + PlanarMapsObservationSpace, TensorObservationSpace, StateSpace, RewardSpace from rl_coach.filters.filter import NoInputFilter, NoOutputFilter from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter @@ -176,11 +177,26 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper): # Environment +class ObservationSpaceType(IntEnum): + Tensor = 0 + Image = 1 + Vector = 2 + + class GymEnvironment(Environment): - def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters, - target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None, - human_control: bool=False, custom_reward_threshold: Union[int, float]=None, - random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs): + def __init__(self, + level: LevelSelection, + frame_skip: int, + visualization_parameters: VisualizationParameters, + target_success_rate: float=1.0, + additional_simulator_parameters: Dict[str, Any] = {}, + seed: Union[None, int] = None, + human_control: bool=False, + custom_reward_threshold: Union[int, float]=None, + random_initialization_steps: int=1, + max_over_num_frames: int=1, + observation_space_type: ObservationSpaceType=None, + **kwargs): """ :param level: (str) A string representing the gym level to run. This can also be a LevelSelection object. @@ -215,6 +231,11 @@ class GymEnvironment(Environment): This value will be used for merging multiple frames into a single frame by taking the maximum value for each of the pixels in the frame. This is particularly used in Atari games, where the frames flicker, and objects can be seen in one frame but disappear in the next. + + :param observation_space_type: + This value will be used for generating observation space. Allows a custom space. Should be one of + ObservationSpaceType. If not specified, observation space is inferred from the number of dimensions + of the observation: 1D: Vector space, 3D: Image space if 1 or 3 channels, PlanarMaps space otherwise. """ super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold, visualization_parameters, target_success_rate) @@ -305,20 +326,40 @@ class GymEnvironment(Environment): state_space = self.env.observation_space.spaces for observation_space_name, observation_space in state_space.items(): - if len(observation_space.shape) == 3: + if observation_space_type == ObservationSpaceType.Tensor: + # we consider arbitrary input tensor which does not necessarily represent images + self.state_space[observation_space_name] = TensorObservationSpace( + shape=np.array(observation_space.shape), + low=observation_space.low, + high=observation_space.high + ) + elif observation_space_type == ObservationSpaceType.Image or len(observation_space.shape) == 3: # we assume gym has image observations (with arbitrary number of channels) where their values are # within 0-255, and where the channel dimension is the last dimension - self.state_space[observation_space_name] = ImageObservationSpace( - shape=np.array(observation_space.shape), - high=255, - channels_axis=-1 - ) - else: + if observation_space.shape[-1] in [1, 3]: + self.state_space[observation_space_name] = ImageObservationSpace( + shape=np.array(observation_space.shape), + high=255, + channels_axis=-1 + ) + else: + # For any number of channels other than 1 or 3, use the generic PlanarMaps space + self.state_space[observation_space_name] = PlanarMapsObservationSpace( + shape=np.array(observation_space.shape), + low=0, + high=255, + channels_axis=-1 + ) + elif observation_space_type == ObservationSpaceType.Vector or len(observation_space.shape) == 1: self.state_space[observation_space_name] = VectorObservationSpace( shape=observation_space.shape[0], low=observation_space.low, high=observation_space.high ) + else: + raise screen.error("Failed to instantiate Gym environment class %s with observation space type %s" % + (env_class, observation_space_type), crash=True) + if 'desired_goal' in state_space.keys(): self.goal_space = self.state_space['desired_goal'] diff --git a/rl_coach/spaces.py b/rl_coach/spaces.py index c3c61e9..7db95a3 100644 --- a/rl_coach/spaces.py +++ b/rl_coach/spaces.py @@ -183,7 +183,7 @@ class ObservationSpace(Space): class VectorObservationSpace(ObservationSpace): """ An observation space which is defined as a vector of elements. This can be particularly useful for environments - which return measurements, such as in robotic environmnets. + which return measurements, such as in robotic environments. """ def __init__(self, shape: int, low: Union[None, int, float, np.ndarray]=-np.inf, high: Union[None, int, float, np.ndarray]=np.inf, measurements_names: List[str]=None): @@ -197,6 +197,16 @@ class VectorObservationSpace(ObservationSpace): super().__init__(shape, low, high) +class TensorObservationSpace(ObservationSpace): + """ + An observation space which defines observations with arbitrary shape. This can be particularly useful for + environments with non image input. + """ + def __init__(self, shape: np.ndarray, low: -np.inf, + high: np.inf): + super().__init__(shape, low, high) + + class PlanarMapsObservationSpace(ObservationSpace): """ An observation space which defines a stack of 2D observations. For example, an environment which returns