mirror of
https://github.com/gryf/coach.git
synced 2025-12-17 19:20:19 +01:00
Add tensor input type for arbitrary dimensional observation (#125)
* Allow arbitrary dimensional observation (non vector or image) * Added creating PlanarMapsObservationSpace to GymEnvironment when number of channels is not 1 or 3
This commit is contained in:
committed by
Gal Leibovich
parent
7ba1a4393f
commit
67a90ee87e
@@ -30,9 +30,9 @@ class InputEmbedderParameters(NetworkComponentParameters):
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
if input_rescaling is None:
|
||||
input_rescaling = {'image': 255.0, 'vector': 1.0}
|
||||
input_rescaling = {'image': 255.0, 'vector': 1.0, 'tensor': 1.0}
|
||||
if input_offset is None:
|
||||
input_offset = {'image': 0.0, 'vector': 0.0}
|
||||
input_offset = {'image': 0.0, 'vector': 0.0, 'tensor': 0.0}
|
||||
|
||||
self.input_rescaling = input_rescaling
|
||||
self.input_offset = input_offset
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from .image_embedder import ImageEmbedder
|
||||
from .tensor_embedder import TensorEmbedder
|
||||
from .vector_embedder import VectorEmbedder
|
||||
|
||||
__all__ = ['ImageEmbedder', 'VectorEmbedder']
|
||||
__all__ = ['ImageEmbedder',
|
||||
'TensorEmbedder',
|
||||
'VectorEmbedder']
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Union
|
||||
from types import ModuleType
|
||||
|
||||
import mxnet as mx
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.embedders.embedder import InputEmbedder
|
||||
|
||||
nd_sym_type = Union[mx.nd.NDArray, mx.sym.Symbol]
|
||||
|
||||
|
||||
class TensorEmbedder(InputEmbedder):
|
||||
def __init__(self, params: InputEmbedderParameters):
|
||||
"""
|
||||
A tensor embedder is an input embedder that takes a tensor with arbitrary dimension and produces a vector
|
||||
embedding by passing it through a neural network. An example is video data or 3D image data (i.e. 4D tensors)
|
||||
or other type of data that is more than 1 dimension (i.e. not vector) but is not an image.
|
||||
|
||||
NOTE: There are no pre-defined schemes for tensor embedder. User must define a custom scheme by passing
|
||||
a callable object as InputEmbedderParameters.scheme when defining the respective preset. This callable
|
||||
object must return a Gluon HybridBlock. The hybrid_forward() of this block must accept a single input,
|
||||
normalized observation, and return an embedding vector for each sample in the batch.
|
||||
Keep in mind that the scheme is a list of blocks, which are stacked by optional batchnorm,
|
||||
activation, and dropout in between as specified in InputEmbedderParameters.
|
||||
|
||||
:param params: parameters object containing input_clipping, input_rescaling, batchnorm, activation_function
|
||||
and dropout properties.
|
||||
"""
|
||||
super(TensorEmbedder, self).__init__(params)
|
||||
self.input_rescaling = params.input_rescaling['tensor']
|
||||
self.input_offset = params.input_offset['tensor']
|
||||
|
||||
@property
|
||||
def schemes(self) -> dict:
|
||||
"""
|
||||
Schemes are the pre-defined network architectures of various depths and complexities that can be used. Are used
|
||||
to create Block when InputEmbedder is initialised.
|
||||
|
||||
Note: Tensor embedder doesn't define any pre-defined scheme. User must provide custom scheme in preset.
|
||||
|
||||
:return: dictionary of schemes, with key of type EmbedderScheme enum and value being list of mxnet.gluon.Block.
|
||||
For tensor embedder, this is an empty dictionary.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
|
||||
"""
|
||||
Used for forward pass through embedder network.
|
||||
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
|
||||
:return: embedding of environment state, of shape (batch_size, channels).
|
||||
"""
|
||||
return super(TensorEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
|
||||
@@ -33,12 +33,12 @@ from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadPara
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
|
||||
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
|
||||
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
|
||||
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
||||
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||
|
||||
|
||||
class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
@@ -172,7 +172,9 @@ def _get_input_embedder(spaces: SpacesDefinition,
|
||||
.format(input_name, allowed_inputs.keys()))
|
||||
|
||||
type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
if isinstance(allowed_inputs[input_name], TensorObservationSpace):
|
||||
type = "tensor"
|
||||
elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
type = "image"
|
||||
|
||||
def sanitize_params(params: InputEmbedderParameters):
|
||||
@@ -187,8 +189,10 @@ def _get_input_embedder(spaces: SpacesDefinition,
|
||||
module = VectorEmbedder(embedder_params)
|
||||
elif type == 'image':
|
||||
module = ImageEmbedder(embedder_params)
|
||||
elif type == 'tensor':
|
||||
module = TensorEmbedder(embedder_params)
|
||||
else:
|
||||
raise KeyError('Unsupported embedder type: {}'.format(type))
|
||||
raise KeyError('Unsupported embedder type: {}'.format(type))
|
||||
return module
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from .image_embedder import ImageEmbedder
|
||||
from .vector_embedder import VectorEmbedder
|
||||
from .tensor_embedder import TensorEmbedder
|
||||
|
||||
__all__ = ['ImageEmbedder', 'VectorEmbedder']
|
||||
__all__ = ['ImageEmbedder', 'VectorEmbedder', 'TensorEmbedder']
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
#
|
||||
# Copyright (c) 2017 Intel Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from rl_coach.architectures.tensorflow_components.layers import Conv2d, Dense
|
||||
from rl_coach.architectures.tensorflow_components.embedders.embedder import InputEmbedder
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
from rl_coach.core_types import InputTensorEmbedding
|
||||
|
||||
|
||||
class TensorEmbedder(InputEmbedder):
|
||||
"""
|
||||
A tensor embedder is an input embedder that takes a tensor with arbitrary dimension and produces a vector
|
||||
embedding by passing it through a neural network. An example is video data or 3D image data (i.e. 4D tensors)
|
||||
or other type of data that is more than 1 dimension (i.e. not vector) but is not an image.
|
||||
|
||||
NOTE: There are no pre-defined schemes for tensor embedder. User must define a custom scheme by passing
|
||||
a callable object as InputEmbedderParameters.scheme when defining the respective preset. This callable
|
||||
object must accept a single input, the normalized observation, and return a Tensorflow symbol which
|
||||
will calculate an embedding vector for each sample in the batch.
|
||||
Keep in mind that the scheme is a list of Tensorflow symbols, which are stacked by optional batchnorm,
|
||||
activation, and dropout in between as specified in InputEmbedderParameters.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: List[int], activation_function=tf.nn.relu,
|
||||
scheme: EmbedderScheme=None, batchnorm: bool=False, dropout_rate: float=0.0,
|
||||
name: str= "embedder", input_rescaling: float=1.0, input_offset: float=0.0, input_clipping=None,
|
||||
dense_layer=Dense, is_training=False):
|
||||
super().__init__(input_size, activation_function, scheme, batchnorm, dropout_rate, name, input_rescaling,
|
||||
input_offset, input_clipping, dense_layer=dense_layer, is_training=is_training)
|
||||
self.return_type = InputTensorEmbedding
|
||||
assert scheme is not None, "Custom scheme (a list of callables) must be specified for TensorEmbedder"
|
||||
|
||||
@property
|
||||
def schemes(self):
|
||||
return {}
|
||||
@@ -27,7 +27,7 @@ from rl_coach.architectures.tensorflow_components.architecture import TensorFlow
|
||||
from rl_coach.architectures.tensorflow_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.core_types import PredictionType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||
from rl_coach.utils import get_all_subclasses, dynamic_import_and_instantiate_module_from_params, indent_string
|
||||
|
||||
|
||||
@@ -116,10 +116,12 @@ class GeneralTensorFlowNetwork(TensorFlowArchitecture):
|
||||
raise ValueError("The key for the input embedder ({}) must match one of the following keys: {}"
|
||||
.format(input_name, allowed_inputs.keys()))
|
||||
|
||||
mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder'}
|
||||
mod_names = {'image': 'ImageEmbedder', 'vector': 'VectorEmbedder', 'tensor': 'TensorEmbedder'}
|
||||
|
||||
emb_type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
if isinstance(allowed_inputs[input_name], TensorObservationSpace):
|
||||
emb_type = "tensor"
|
||||
elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
emb_type = "image"
|
||||
|
||||
embedder_path = 'rl_coach.architectures.tensorflow_components.embedders:' + mod_names[emb_type]
|
||||
|
||||
@@ -114,6 +114,10 @@ class InputVectorEmbedding(InputEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class InputTensorEmbedding(InputEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
class Middleware_FC_Embedding(MiddlewareEmbedding):
|
||||
pass
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
from enum import IntEnum
|
||||
import scipy.ndimage
|
||||
|
||||
from rl_coach.graph_managers.graph_manager import ScheduleParameters
|
||||
@@ -44,7 +45,7 @@ from typing import Dict, Any, Union
|
||||
from rl_coach.core_types import RunPhase, EnvironmentSteps
|
||||
from rl_coach.environments.environment import Environment, EnvironmentParameters, LevelSelection
|
||||
from rl_coach.spaces import DiscreteActionSpace, BoxActionSpace, ImageObservationSpace, VectorObservationSpace, \
|
||||
StateSpace, RewardSpace
|
||||
PlanarMapsObservationSpace, TensorObservationSpace, StateSpace, RewardSpace
|
||||
from rl_coach.filters.filter import NoInputFilter, NoOutputFilter
|
||||
from rl_coach.filters.reward.reward_clipping_filter import RewardClippingFilter
|
||||
from rl_coach.filters.observation.observation_rescale_to_size_filter import ObservationRescaleToSizeFilter
|
||||
@@ -176,11 +177,26 @@ class MaxOverFramesAndFrameskipEnvWrapper(gym.Wrapper):
|
||||
|
||||
|
||||
# Environment
|
||||
class ObservationSpaceType(IntEnum):
|
||||
Tensor = 0
|
||||
Image = 1
|
||||
Vector = 2
|
||||
|
||||
|
||||
class GymEnvironment(Environment):
|
||||
def __init__(self, level: LevelSelection, frame_skip: int, visualization_parameters: VisualizationParameters,
|
||||
target_success_rate: float=1.0, additional_simulator_parameters: Dict[str, Any] = {}, seed: Union[None, int]=None,
|
||||
human_control: bool=False, custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1, max_over_num_frames: int=1, **kwargs):
|
||||
def __init__(self,
|
||||
level: LevelSelection,
|
||||
frame_skip: int,
|
||||
visualization_parameters: VisualizationParameters,
|
||||
target_success_rate: float=1.0,
|
||||
additional_simulator_parameters: Dict[str, Any] = {},
|
||||
seed: Union[None, int] = None,
|
||||
human_control: bool=False,
|
||||
custom_reward_threshold: Union[int, float]=None,
|
||||
random_initialization_steps: int=1,
|
||||
max_over_num_frames: int=1,
|
||||
observation_space_type: ObservationSpaceType=None,
|
||||
**kwargs):
|
||||
"""
|
||||
:param level: (str)
|
||||
A string representing the gym level to run. This can also be a LevelSelection object.
|
||||
@@ -215,6 +231,11 @@ class GymEnvironment(Environment):
|
||||
This value will be used for merging multiple frames into a single frame by taking the maximum value for each
|
||||
of the pixels in the frame. This is particularly used in Atari games, where the frames flicker, and objects
|
||||
can be seen in one frame but disappear in the next.
|
||||
|
||||
:param observation_space_type:
|
||||
This value will be used for generating observation space. Allows a custom space. Should be one of
|
||||
ObservationSpaceType. If not specified, observation space is inferred from the number of dimensions
|
||||
of the observation: 1D: Vector space, 3D: Image space if 1 or 3 channels, PlanarMaps space otherwise.
|
||||
"""
|
||||
super().__init__(level, seed, frame_skip, human_control, custom_reward_threshold,
|
||||
visualization_parameters, target_success_rate)
|
||||
@@ -305,20 +326,40 @@ class GymEnvironment(Environment):
|
||||
state_space = self.env.observation_space.spaces
|
||||
|
||||
for observation_space_name, observation_space in state_space.items():
|
||||
if len(observation_space.shape) == 3:
|
||||
if observation_space_type == ObservationSpaceType.Tensor:
|
||||
# we consider arbitrary input tensor which does not necessarily represent images
|
||||
self.state_space[observation_space_name] = TensorObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
low=observation_space.low,
|
||||
high=observation_space.high
|
||||
)
|
||||
elif observation_space_type == ObservationSpaceType.Image or len(observation_space.shape) == 3:
|
||||
# we assume gym has image observations (with arbitrary number of channels) where their values are
|
||||
# within 0-255, and where the channel dimension is the last dimension
|
||||
self.state_space[observation_space_name] = ImageObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
else:
|
||||
if observation_space.shape[-1] in [1, 3]:
|
||||
self.state_space[observation_space_name] = ImageObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
else:
|
||||
# For any number of channels other than 1 or 3, use the generic PlanarMaps space
|
||||
self.state_space[observation_space_name] = PlanarMapsObservationSpace(
|
||||
shape=np.array(observation_space.shape),
|
||||
low=0,
|
||||
high=255,
|
||||
channels_axis=-1
|
||||
)
|
||||
elif observation_space_type == ObservationSpaceType.Vector or len(observation_space.shape) == 1:
|
||||
self.state_space[observation_space_name] = VectorObservationSpace(
|
||||
shape=observation_space.shape[0],
|
||||
low=observation_space.low,
|
||||
high=observation_space.high
|
||||
)
|
||||
else:
|
||||
raise screen.error("Failed to instantiate Gym environment class %s with observation space type %s" %
|
||||
(env_class, observation_space_type), crash=True)
|
||||
|
||||
if 'desired_goal' in state_space.keys():
|
||||
self.goal_space = self.state_space['desired_goal']
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ class ObservationSpace(Space):
|
||||
class VectorObservationSpace(ObservationSpace):
|
||||
"""
|
||||
An observation space which is defined as a vector of elements. This can be particularly useful for environments
|
||||
which return measurements, such as in robotic environmnets.
|
||||
which return measurements, such as in robotic environments.
|
||||
"""
|
||||
def __init__(self, shape: int, low: Union[None, int, float, np.ndarray]=-np.inf,
|
||||
high: Union[None, int, float, np.ndarray]=np.inf, measurements_names: List[str]=None):
|
||||
@@ -197,6 +197,16 @@ class VectorObservationSpace(ObservationSpace):
|
||||
super().__init__(shape, low, high)
|
||||
|
||||
|
||||
class TensorObservationSpace(ObservationSpace):
|
||||
"""
|
||||
An observation space which defines observations with arbitrary shape. This can be particularly useful for
|
||||
environments with non image input.
|
||||
"""
|
||||
def __init__(self, shape: np.ndarray, low: -np.inf,
|
||||
high: np.inf):
|
||||
super().__init__(shape, low, high)
|
||||
|
||||
|
||||
class PlanarMapsObservationSpace(ObservationSpace):
|
||||
"""
|
||||
An observation space which defines a stack of 2D observations. For example, an environment which returns
|
||||
|
||||
Reference in New Issue
Block a user