mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Add tensor input type for arbitrary dimensional observation (#125)
* Allow arbitrary dimensional observation (non vector or image) * Added creating PlanarMapsObservationSpace to GymEnvironment when number of channels is not 1 or 3
This commit is contained in:
committed by
Gal Leibovich
parent
7ba1a4393f
commit
67a90ee87e
@@ -33,12 +33,12 @@ from rl_coach.architectures.head_parameters import PPOVHeadParameters, VHeadPara
|
||||
from rl_coach.architectures.middleware_parameters import MiddlewareParameters
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters, LSTMMiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.architecture import MxnetArchitecture
|
||||
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, VectorEmbedder
|
||||
from rl_coach.architectures.mxnet_components.embedders import ImageEmbedder, TensorEmbedder, VectorEmbedder
|
||||
from rl_coach.architectures.mxnet_components.heads import Head, HeadLoss, PPOHead, PPOVHead, VHead, QHead
|
||||
from rl_coach.architectures.mxnet_components.middlewares import FCMiddleware, LSTMMiddleware
|
||||
from rl_coach.architectures.mxnet_components import utils
|
||||
from rl_coach.base_parameters import AgentParameters, EmbeddingMergerType
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace
|
||||
from rl_coach.spaces import SpacesDefinition, PlanarMapsObservationSpace, TensorObservationSpace
|
||||
|
||||
|
||||
class GeneralMxnetNetwork(MxnetArchitecture):
|
||||
@@ -172,7 +172,9 @@ def _get_input_embedder(spaces: SpacesDefinition,
|
||||
.format(input_name, allowed_inputs.keys()))
|
||||
|
||||
type = "vector"
|
||||
if isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
if isinstance(allowed_inputs[input_name], TensorObservationSpace):
|
||||
type = "tensor"
|
||||
elif isinstance(allowed_inputs[input_name], PlanarMapsObservationSpace):
|
||||
type = "image"
|
||||
|
||||
def sanitize_params(params: InputEmbedderParameters):
|
||||
@@ -187,8 +189,10 @@ def _get_input_embedder(spaces: SpacesDefinition,
|
||||
module = VectorEmbedder(embedder_params)
|
||||
elif type == 'image':
|
||||
module = ImageEmbedder(embedder_params)
|
||||
elif type == 'tensor':
|
||||
module = TensorEmbedder(embedder_params)
|
||||
else:
|
||||
raise KeyError('Unsupported embedder type: {}'.format(type))
|
||||
raise KeyError('Unsupported embedder type: {}'.format(type))
|
||||
return module
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user