1
0
mirror of https://github.com/gryf/coach.git synced 2026-04-10 15:13:40 +02:00

Adding mxnet components to rl_coach/architectures (#60)

Adding mxnet components to rl_coach architectures.

- Supports PPO and DQN
- Tested with CartPole_PPO and CartPole_DQN
- Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset
- Checkpointing is disabled for MXNet
This commit is contained in:
Sina Afrooze
2018-11-07 07:07:15 -08:00
committed by Itai Caspi
parent e7a91b4dc3
commit 5fadb9c18e
39 changed files with 3864 additions and 44 deletions

View File

@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware
@pytest.mark.unit_test
def test_fc_middleware():
    """FCMiddleware (Medium scheme) maps a (10, 100) batch to a (10, 512) output."""
    middleware = FCMiddleware(params=FCMiddlewareParameters(scheme=MiddlewareScheme.Medium))
    middleware.initialize()
    batch = mx.nd.random.uniform(low=0, high=1, shape=(10, 100))
    result = middleware(batch)
    # The trailing flatten block yields a 2-D tensor: the batch dimension (10)
    # is preserved and the feature dimension equals the 512 units of the
    # Medium scheme's final layer.
    assert result.ndim == 2
    assert result.shape == (10, 512)

View File

@@ -0,0 +1,25 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.lstm_middleware import LSTMMiddleware
@pytest.mark.unit_test
def test_lstm_middleware():
    """LSTMMiddleware with 25 cells turns NTC input (10, 15, 20) into TNC output (15, 10, 25)."""
    middleware = LSTMMiddleware(
        params=LSTMMiddlewareParameters(number_of_lstm_cells=25, scheme=MiddlewareScheme.Medium))
    middleware.initialize()
    # Input is batch-major (NTC): batch=10, time=15, channels=20.
    batch = mx.nd.random.uniform(low=0, high=1, shape=(10, 15, 20))
    # The middleware emits a time-major (TNC) sequence.
    result = middleware(batch)
    # 3-D output: time steps (15), batch (10), then the 25 LSTM cell units.
    assert result.ndim == 3
    assert result.shape == (15, 10, 25)