mirror of
https://github.com/gryf/coach.git
synced 2026-03-27 05:33:32 +01:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CartPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
import mxnet as mx
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
# Append the directory two levels above this file's own directory to sys.path
# (presumably the repository root, so the rl_coach imports below resolve when
# the test is run in place -- confirm against the actual file location).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.embedders.image_embedder import ImageEmbedder
|
||||
|
||||
|
||||
@pytest.mark.unit_test
def test_image_embedder():
    """Feed a random 4-D batch through a Medium-scheme ImageEmbedder.

    Only the shape of the result is checked: it must be 2-D and its first
    dimension must equal the number of samples in the batch.
    """
    batch_size = 10

    embedder = ImageEmbedder(
        params=InputEmbedderParameters(scheme=EmbedderScheme.Medium)
    )
    embedder.initialize()

    # Random input of shape (batch_size, 3, 244, 244), sampled with low=0, high=1.
    images = mx.nd.random.uniform(
        low=0, high=1, shape=(batch_size, 3, 244, 244)
    )
    out_shape = embedder(images).shape

    # The original author expects the embedder to end with a flatten block,
    # hence a 2-D result.
    assert len(out_shape) == 2
    # The leading dimension must be preserved as the batch dimension.
    assert out_shape[0] == batch_size
|
||||
@@ -0,0 +1,22 @@
|
||||
import mxnet as mx
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
# Append the directory two levels above this file's own directory to sys.path
# (presumably the repository root, so the rl_coach imports below resolve when
# the test is run in place -- confirm against the actual file location).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
|
||||
from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder
|
||||
from rl_coach.base_parameters import EmbedderScheme
|
||||
|
||||
|
||||
@pytest.mark.unit_test
def test_vector_embedder():
    """Feed a random 2-D batch through a Medium-scheme VectorEmbedder.

    Only the shape of the result is checked: it must be 2-D, keep the batch
    size as its first dimension, and have 256 entries in its second dimension.
    """
    batch_size = 10

    embedder = VectorEmbedder(
        params=InputEmbedderParameters(scheme=EmbedderScheme.Medium)
    )
    embedder.initialize()

    # Random input of shape (batch_size, 100), sampled with low=0, high=255.
    vectors = mx.nd.random.uniform(low=0, high=255, shape=(batch_size, 100))
    out_shape = embedder(vectors).shape

    # A batch of vectors in should give a 2-D result out.
    assert len(out_shape) == 2
    # The leading dimension must be preserved as the batch dimension.
    assert out_shape[0] == batch_size
    # Per the original author, the last dense layer has 256 units.
    assert out_shape[1] == 256
|
||||
Reference in New Issue
Block a user