
Add tensor input type for arbitrary dimensional observation (#125)

* Allow arbitrary-dimensional observations (not only vectors or images)
* Added creation of a PlanarMapsObservationSpace in GymEnvironment when the number of channels is not 1 or 3
Authored by Sina Afrooze on 2018-11-19 06:41:12 -08:00, committed by Gal Leibovich
parent 7ba1a4393f
commit 67a90ee87e
10 changed files with 194 additions and 24 deletions
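
Taken together, the files below wire a new 'tensor' input type through both backends. As a hedged overview (illustrative fragment only; the per-file details follow), a preset author now supplies two things:

from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.environments.gym_environment import ObservationSpaceType

# 1. ask GymEnvironment to expose the raw observation as a TensorObservationSpace
observation_space_type = ObservationSpaceType.Tensor

# 2. provide a custom scheme, since TensorEmbedder ships with no pre-defined schemes
embedder_params = InputEmbedderParameters(scheme=[...])  # list of callables/blocks, see below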

rl_coach/architectures/embedder_parameters.py

@@ -30,9 +30,9 @@ class InputEmbedderParameters(NetworkComponentParameters):
         self.dropout_rate = dropout_rate

         if input_rescaling is None:
-            input_rescaling = {'image': 255.0, 'vector': 1.0}
+            input_rescaling = {'image': 255.0, 'vector': 1.0, 'tensor': 1.0}
         if input_offset is None:
-            input_offset = {'image': 0.0, 'vector': 0.0}
+            input_offset = {'image': 0.0, 'vector': 0.0, 'tensor': 0.0}

         self.input_rescaling = input_rescaling
         self.input_offset = input_offset
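
For reference, a minimal sketch of what the new 'tensor' defaults do, assuming the embedder divides by the rescaling factor and then subtracts the offset (the order the image defaults of 255.0 and 0.0 suggest): tensor observations pass through unchanged.

import numpy as np

def normalize(obs, rescaling=1.0, offset=0.0):
    # assumed normalization order: rescale first, then offset
    return obs / rescaling - offset

video = np.random.rand(4, 16, 84, 84, 3).astype(np.float32)  # a batch of 4D observations
assert np.allclose(normalize(video), video)  # the 'tensor' defaults are a no-op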

rl_coach/architectures/mxnet_components/embedders/__init__.py

@@ -1,4 +1,7 @@
 from .image_embedder import ImageEmbedder
+from .tensor_embedder import TensorEmbedder
 from .vector_embedder import VectorEmbedder

-__all__ = ['ImageEmbedder', 'VectorEmbedder']
+__all__ = ['ImageEmbedder',
+           'TensorEmbedder',
+           'VectorEmbedder']

rl_coach/architectures/mxnet_components/embedders/tensor_embedder.py

@@ -0,0 +1,53 @@
from typing import Union
from types import ModuleType

import mxnet as mx

from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder

nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]


class TensorEmbedder(InputEmbedder):
    def __init__(self, params: InputEmbedderParameters):
        """
        A tensor embedder is an input embedder that takes a tensor of arbitrary dimension and produces a vector
        embedding by passing it through a neural network. Examples are video data and 3D image data (i.e. 4D
        tensors), or any other data that has more than one dimension (i.e. is not a vector) but is not an image.

        NOTE: There are no pre-defined schemes for the tensor embedder. The user must define a custom scheme by
        passing a callable object as InputEmbedderParameters.scheme when defining the respective preset. This
        callable object must return a Gluon HybridBlock. The hybrid_forward() of this block must accept a single
        input, the normalized observation, and return an embedding vector for each sample in the batch.

        Keep in mind that the scheme is a list of blocks, which are stacked with optional batchnorm, activation,
        and dropout in between, as specified in InputEmbedderParameters.

        :param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
            and dropout properties.
        """
        super(TensorEmbedder, self).__init__(params)
        self.input_rescaling = params.input_rescaling['tensor']
        self.input_offset = params.input_offset['tensor']

    @property
    def schemes(self) -> dict:
        """
        Schemes are the pre-defined network architectures of various depths and complexities that can be used.
        They are used to create the Block when the InputEmbedder is initialised.

        Note: the tensor embedder does not define any pre-defined schemes; the user must provide a custom scheme
        in the preset.

        :return: dictionary of schemes, with keys of type EmbedderScheme enum and values being lists of
            mxnet.gluon.Block. For the tensor embedder, this is an empty dictionary.
        """
        return {}

    def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
        """
        Used for the forward pass through the embedder network.

        :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if the block has been hybridized).
        :param x: tensor representing the environment state, of shape (batch_size, ...).
        :return: embedding of the environment state, of shape (batch_size, channels).
        """
        return super(TensorEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
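
Since no schemes are pre-defined, every preset using this embedder needs a custom one. A minimal sketch of a Gluon block satisfying the contract in the docstring above; the block name, layer sizes, and single-entry scheme list are illustrative assumptions, not code from this commit:

import mxnet as mx
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters

class Conv3DEmbedding(mx.gluon.HybridBlock):
    """Embeds a batch of 4D observations (e.g. video) into per-sample vectors."""
    def __init__(self, **kwargs):
        super(Conv3DEmbedding, self).__init__(**kwargs)
        with self.name_scope():
            self.conv = mx.gluon.nn.Conv3D(channels=8, kernel_size=3, activation='relu')
            self.flatten = mx.gluon.nn.Flatten()

    def hybrid_forward(self, F, x):
        # x: the normalized observation; returns one embedding vector per sample
        return self.flatten(self.conv(x))

# the scheme is a list of blocks, stacked with optional batchnorm/activation/dropout
embedder_params = InputEmbedderParameters(scheme=[Conv3DEmbedding()])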

rl_coach/architectures/mxnet_components/general_network.py

@@ -33,12 +33,12 @@ from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadPara
 from rl_coach.architectures.middleware_parameters import MiddlewareParameters
 from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
 from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
-from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
+from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
 from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
 from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
 from rl_coach.architectures.mxnet_components import utils
 from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
-from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
+from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace


 class GeneralMxnetNetwork(MxnetArchitecture):
@@ -172,7 +172,9 @@ def _get_input_embedder(spaces: SpacesDefinition,
                         .format(input_name, allowed_inputs.keys()))

     type = "vector"
-    if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
+    if isinstance(allowed_inputs[input_name], TensorObservationSpace):
+        type = "tensor"
+    elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
         type = "image"

     def sanitize_params(params: InputEmbedderParameters):
@@ -187,8 +189,10 @@ def _get_input_embedder(spaces: SpacesDefinition,
         module = VectorEmbedder(embedder_params)
     elif type == 'image':
         module = ImageEmbedder(embedder_params)
+    elif type == 'tensor':
+        module = TensorEmbedder(embedder_params)
     else:
         raise KeyError('Unsupported embedder type: {}'.format(type))

     return module
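
A small sanity sketch (not a test from this commit) of the dispatch above: a TensorObservationSpace input now resolves to the 'tensor' embedder type before the PlanarMaps check is reached:

import numpy as np
from rl_coach.spaces import TensorObservationSpace, PlanarMapsObservationSpace

space = TensorObservationSpace(shape=np.array([16, 84, 84, 3]), low=0, high=255)

emb_type = "vector"
if isinstance(space, TensorObservationSpace):
    emb_type = "tensor"
elif isinstance(space, PlanarMapsObservationSpace):
    emb_type = "image"
assert emb_type == "tensor"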

rl_coach/architectures/tensorflow_components/embedders/__init__.py

@@ -1,4 +1,5 @@
 from .image_embedder import ImageEmbedder
 from .vector_embedder import VectorEmbedder
+from .tensor_embedder import TensorEmbedder

-__all__ = ['ImageEmbedder', 'VectorEmbedder']
+__all__ = ['ImageEmbedder', 'VectorEmbedder', 'TensorEmbedder']

rl_coach/architectures/tensorflow_components/embedders/tensor_embedder.py

@@ -0,0 +1,52 @@
#
# Copyright (c) 2017 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import List

import tensorflow as tf

from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.core_types import InputTensorEmbedding


class TensorEmbedder(InputEmbedder):
    """
    A tensor embedder is an input embedder that takes a tensor of arbitrary dimension and produces a vector
    embedding by passing it through a neural network. Examples are video data and 3D image data (i.e. 4D
    tensors), or any other data that has more than one dimension (i.e. is not a vector) but is not an image.

    NOTE: There are no pre-defined schemes for the tensor embedder. The user must define a custom scheme by
    passing a list of callable objects as InputEmbedderParameters.scheme when defining the respective preset.
    Each callable must accept a single input, the normalized observation, and return a TensorFlow symbol which
    will calculate an embedding vector for each sample in the batch.

    Keep in mind that the scheme is a list of TensorFlow symbols, which are stacked with optional batchnorm,
    activation, and dropout in between, as specified in InputEmbedderParameters.
    """
    def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
                 scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
                 name: str="embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
                 dense_layer=Dense, is_training=False):
        super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
                         input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
        self.return_type = InputTensorEmbedding
        assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"

    @property
    def schemes(self):
        return {}
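
As with the MXNet backend, the preset must supply the scheme. A minimal sketch following the contract above, where each scheme entry maps the normalized observation to a symbol producing one embedding vector per sample; the function name and layer choices are illustrative assumptions (TF 1.x layers API):

import tensorflow as tf
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters

def video_embedding(obs):
    # obs: the normalized observation, e.g. (batch, frames, height, width, channels)
    x = tf.layers.conv3d(obs, filters=8, kernel_size=3, activation=tf.nn.relu)
    return tf.layers.flatten(x)  # -> (batch, embedding_size)

embedder_params = InputEmbedderParameters(scheme=[video_embedding])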

rl_coach/architectures/tensorflow_components/general_network.py

@@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.architecture import TensorFlow
 from rl_coach.architectures.tensorflow_components import utils
 from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
 from rl_coach.core_types import PredictionType
-from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
+from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
 from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
@@ -116,10 +116,12 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
             raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
                              .format(input_name, allowed_inputs.keys()))

-        mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder'}
+        mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'TensorEmbedder'}

         emb_type = "vector"
-        if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
+        if isinstance(allowed_inputs[input_name], TensorObservationSpace):
+            emb_type = "tensor"
+        elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
             emb_type = "image"

         embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type]

rl_coach/core_types.py

@@ -114,6 +114,10 @@ class InputVectorEmbedding(InputEmbedding):
     pass


+class InputTensorEmbedding(InputEmbedding):
+    pass
+
+
 class Middleware_FC_Embedding(MiddlewareEmbedding):
     pass

rl_coach/environments/gym_environment.py

@@ -16,6 +16,7 @@
 import gym
 import numpy as np
+from enum import IntEnum
 import scipy.ndimage

 from rl_coach.graph_managers.graph_manager import ScheduleParameters
@@ -44,7 +45,7 @@ from typing import Dict, Any, Union
 from rl_coach.core_types import RunPhase, EnvironmentSteps
 from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
 from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
-    StateSpace, RewardSpace
+    PlanarMapsObservationSpace, TensorObservationSpace, StateSpace, RewardSpace
 from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
 from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
 from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
@@ -176,11 +177,26 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
 # Environment
+class ObservationSpaceType(IntEnum):
+    Tensor = 0
+    Image = 1
+    Vector = 2
+
+
 class GymEnvironment(Environment):
-    def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
-                 target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
-                 human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
-                 random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
+    def __init__(self,
+                 level: LevelSelection,
+                 frame_skip: int,
+                 visualization_parameters: VisualizationParameters,
+                 target_success_rate: float=1.0,
+                 additional_simulator_parameters: Dict[str, Any] = {},
+                 seed: Union[None, int] = None,
+                 human_control: bool=False,
+                 custom_reward_threshold: Union[int, float]=None,
+                 random_initialization_steps: int=1,
+                 max_over_num_frames: int=1,
+                 observation_space_type: ObservationSpaceType=None,
+                 **kwargs):
         """
         :param level: (str)
             A string representing the gym level to run. This can also be a LevelSelection object.
@@ -215,6 +231,11 @@ class GymEnvironment(Environment):
             This value will be used for merging multiple frames into a single frame by taking the maximum value for each
             of the pixels in the frame. This is particularly used in Atari games, where the frames flicker, and objects
             can be seen in one frame but disappear in the next.
+        :param observation_space_type:
+            This value will be used for generating the observation space. It allows forcing a custom space, and should
+            be one of ObservationSpaceType. If not specified, the observation space is inferred from the number of
+            dimensions of the observation: 1D: Vector space, 3D: Image space if there are 1 or 3 channels, and
+            PlanarMaps space otherwise.
         """
         super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
                          visualization_parameters, target_success_rate)
@@ -305,20 +326,40 @@ class GymEnvironment(Environment):
         state_space = self.env.observation_space.spaces

         for observation_space_name, observation_space in state_space.items():
-            if len(observation_space.shape) == 3:
+            if observation_space_type == ObservationSpaceType.Tensor:
+                # we consider an arbitrary input tensor which does not necessarily represent images
+                self.state_space[observation_space_name] = TensorObservationSpace(
+                    shape=np.array(observation_space.shape),
+                    low=observation_space.low,
+                    high=observation_space.high
+                )
+            elif observation_space_type == ObservationSpaceType.Image or len(observation_space.shape) == 3:
                 # we assume gym has image observations (with arbitrary number of channels) where their values are
                 # within 0-255, and where the channel dimension is the last dimension
-                self.state_space[observation_space_name] = ImageObservationSpace(
-                    shape=np.array(observation_space.shape),
-                    high=255,
-                    channels_axis=-1
-                )
-            else:
+                if observation_space.shape[-1] in [1, 3]:
+                    self.state_space[observation_space_name] = ImageObservationSpace(
+                        shape=np.array(observation_space.shape),
+                        high=255,
+                        channels_axis=-1
+                    )
+                else:
+                    # For any number of channels other than 1 or 3, use the generic PlanarMaps space
+                    self.state_space[observation_space_name] = PlanarMapsObservationSpace(
+                        shape=np.array(observation_space.shape),
+                        low=0,
+                        high=255,
+                        channels_axis=-1
+                    )
+            elif observation_space_type == ObservationSpaceType.Vector or len(observation_space.shape) == 1:
                 self.state_space[observation_space_name] = VectorObservationSpace(
                     shape=observation_space.shape[0],
                     low=observation_space.low,
                     high=observation_space.high
                 )
+            else:
+                raise screen.error("Failed to instantiate Gym environment class %s with observation space type %s" %
+                                   (env_class, observation_space_type), crash=True)

         if 'desired_goal' in state_space.keys():
             self.goal_space = self.state_space['desired_goal']
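
A hedged usage sketch of the new argument; the level string is a placeholder for a registered gym environment, and only observation_space_type is specific to this commit:

from rl_coach.base_parameters import VisualizationParameters
from rl_coach.environments.gym_environment import GymEnvironment, ObservationSpaceType

env = GymEnvironment(level='my_package:MyTensorEnv-v0',  # hypothetical gym level
                     frame_skip=1,
                     visualization_parameters=VisualizationParameters(),
                     seed=0,
                     observation_space_type=ObservationSpaceType.Tensor)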

rl_coach/spaces.py

@@ -183,7 +183,7 @@ class ObservationSpace(Space):
 class VectorObservationSpace(ObservationSpace):
     """
     An observation space which is defined as a vector of elements. This can be particularly useful for environments
-    which return measurements, such as in robotic environmnets.
+    which return measurements, such as in robotic environments.
     """
     def __init__(self, shape: int, low: Union[None, int, float, np.ndarray]=-np.inf,
                  high: Union[None, int, float, np.ndarray]=np.inf, measurements_names: List[str]=None):
@@ -197,6 +197,16 @@ class VectorObservationSpace(ObservationSpace):
         super().__init__(shape, low, high)


+class TensorObservationSpace(ObservationSpace):
+    """
+    An observation space which defines observations with arbitrary shape. This can be particularly useful for
+    environments with non-image input.
+    """
+    def __init__(self, shape: np.ndarray, low: Union[None, int, float, np.ndarray]=-np.inf,
+                 high: Union[None, int, float, np.ndarray]=np.inf):
+        super().__init__(shape, low, high)
+
+
 class PlanarMapsObservationSpace(ObservationSpace):
     """
     An observation space which defines a stack of 2D observations. For example, an environment which returns