mirror of
https://github.com/gryf/coach.git
synced 2026-04-10 15:13:40 +02:00
Adding mxnet components to rl_coach/architectures (#60)
Adding mxnet components to rl_coach architectures. - Supports PPO and DQN - Tested with CartPole_PPO and CartPole_DQN - Normalizing filters don't work right now (see #49) and are disabled in CartPole_PPO preset - Checkpointing is disabled for MXNet
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
import mxnet as mx
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
# Append the directory three dirname() levels above this file to sys.path;
# presumably so the rl_coach imports below resolve when the package is not installed (confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware
|
||||
|
||||
|
||||
@pytest.mark.unit_test
def test_fc_middleware():
    """Check the output shape of a Medium-scheme FCMiddleware on a random 2-D batch."""
    batch_size, input_width, expected_units = 10, 100, 512

    middleware = FCMiddleware(params=FCMiddlewareParameters(scheme=MiddlewareScheme.Medium))
    middleware.initialize()

    # Random input laid out as (batch, features).
    batch = mx.nd.random.uniform(low=0, high=1, shape=(batch_size, input_width))
    result = middleware(batch)

    # The output is expected to stay 2-D: one row per sample, one column per unit.
    assert result.ndim == 2
    assert result.shape[0] == batch_size
    # This test expects the Medium scheme to end in a 512-unit layer.
    assert result.shape[1] == expected_units
|
||||
@@ -0,0 +1,25 @@
|
||||
import mxnet as mx
|
||||
import os
|
||||
import pytest
|
||||
import sys
|
||||
# Append the directory three dirname() levels above this file to sys.path;
# presumably so the rl_coach imports below resolve when the package is not installed (confirm).
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
||||
|
||||
|
||||
from rl_coach.base_parameters import MiddlewareScheme
|
||||
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
|
||||
from rl_coach.architectures.mxnet_components.middlewares.lstm_middleware import LSTMMiddleware
|
||||
|
||||
|
||||
@pytest.mark.unit_test
def test_lstm_middleware():
    """Check the output shape of an LSTMMiddleware fed a random (batch, time, channel) input."""
    batch_size, seq_len, channels, lstm_cells = 10, 15, 20, 25

    middleware = LSTMMiddleware(
        params=LSTMMiddlewareParameters(
            number_of_lstm_cells=lstm_cells,
            scheme=MiddlewareScheme.Medium,
        )
    )
    middleware.initialize()

    # Input is laid out NTC: (batch, time, channels).
    sequence_batch = mx.nd.random.uniform(
        low=0, high=1, shape=(batch_size, seq_len, channels)
    )

    # The assertions below expect the output in TNC order: (time, batch, cells).
    result = middleware(sequence_batch)

    assert result.ndim == 3
    assert result.shape[0] == seq_len
    assert result.shape[1] == batch_size
    assert result.shape[2] == lstm_cells
|
||||
Reference in New Issue
Block a user