1
0
mirror of https://github.com/gryf/coach.git synced 2026-03-27 05:33:32 +01:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CarPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -0,0 +1,21 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import EmbedderScheme
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.image_embedder import ImageEmbedder
@pytest.mark.unit_test
def test_image_embedder():
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
emb = ImageEmbedder(params=params)
emb.initialize()
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
output = emb(input_data)
assert len(output.shape) == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10

View File

@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder
from rl_coach.base_parameters import EmbedderScheme
@pytest.mark.unit_test
def test_vector_embedder():
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
emb = VectorEmbedder(params=params)
emb.initialize()
input_data = mx.nd.random.uniform(low=0, high=255, shape=(10, 100))
output = emb(input_data)
assert len(output.shape) == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10
assert output.shape[1] == 256 # since last dense layer has 256 units