1
0
mirror of https://github.com/gryf/coach.git synced 2025-12-18 11:40:18 +01:00

Add tensor input type for arbitrary dimensional observation (#125)

* Allow arbitrary dimensional observation (non vector or image)
* Added creating PlanarMapsObservationSpace to GymEnvironment when number of channels is not 1 or 3
This commit is contained in:
Sina Afrooze
2018-11-19 06:41:12 -08:00
committed by Gal Leibovich
parent 7ba1a4393f
commit 67a90ee87e
10 changed files with 194 additions and 24 deletions

View File

@@ -33,12 +33,12 @@ from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadPara
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
from rl_coach.architectures.mxnet_components import utils
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
class GeneralMxnetNetwork(MxnetArchitecture):
@@ -172,7 +172,9 @@ def _get_input_embedder(spaces: SpacesDefinition,
.format(input_name, allowed_inputs.keys()))
type = "vector"
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
if isinstance(allowed_inputs[input_name], TensorObservationSpace):
type = "tensor"
elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
type = "image"
def sanitize_params(params: InputEmbedderParameters):
@@ -187,8 +189,10 @@ def _get_input_embedder(spaces: SpacesDefinition,
module = VectorEmbedder(embedder_params)
elif type == 'image':
module = ImageEmbedder(embedder_params)
elif type == 'tensor':
module = TensorEmbedder(embedder_params)
else:
raise KeyError('Unsupported embedder type: {}'.format(type))
raise KeyError('Unsupported embedder type: {}'.format(type))
return module