Mirror of https://github.com/gryf/coach.git, synced 2026-02-17 06:35:47 +01:00
Adding mxnet components to rl_coach/architectures (#60)

Adding MXNet components to rl_coach/architectures:
- Supports PPO and DQN
- Tested with the CartPole_PPO and CartPole_DQN presets
- Normalizing filters don't work right now (see #49) and are disabled in the CartPole_PPO preset
- Checkpointing is disabled for MXNet
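For orientation, the new MXNet components compose the same way as the existing TensorFlow ones: an input embedder feeds a middleware, which feeds one or more heads. A minimal standalone sketch of that flow, pieced together from the tests below (in Coach proper the wiring is done by the network/architecture classes, so treat this direct pairing as illustrative only):

import mxnet as mx
from rl_coach.base_parameters import EmbedderScheme, MiddlewareScheme
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware

# embedder -> middleware, each initialized and then called directly on an NDArray batch
emb = VectorEmbedder(params=InputEmbedderParameters(scheme=EmbedderScheme.Medium))
emb.initialize()
mid = FCMiddleware(params=FCMiddlewareParameters(scheme=MiddlewareScheme.Medium))
mid.initialize()

observations = mx.nd.random.uniform(shape=(10, 4))  # e.g. a batch of 10 CartPole-like observations
features = mid(emb(observations))                   # (10, 512), per the Medium middleware scheme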
@@ -0,0 +1,21 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.base_parameters import EmbedderScheme
from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.image_embedder import ImageEmbedder


@pytest.mark.unit_test
def test_image_embedder():
    params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
    emb = ImageEmbedder(params=params)
    emb.initialize()
    input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
    output = emb(input_data)
    assert len(output.shape) == 2  # since last block was flatten
    assert output.shape[0] == 10  # since batch_size is 10
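Note the channels-first (NCHW) layout of the random image batch above; MXNet convolutions default to NCHW, so a channels-last (HWC) frame coming from an environment would need a transpose before being fed in. For illustration (the img_hwc variable is hypothetical, not from this commit):

import mxnet as mx

img_hwc = mx.nd.random.uniform(shape=(244, 244, 3))  # a single HWC frame, stand-in for an env observation
img_nchw = mx.nd.transpose(img_hwc, axes=(2, 0, 1)).expand_dims(axis=0)  # -> (1, 3, 244, 244), as the test expects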
@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.architectures.embedder_parameters import InputEmbedderParameters
from rl_coach.architectures.mxnet_components.embedders.vector_embedder import VectorEmbedder
from rl_coach.base_parameters import EmbedderScheme


@pytest.mark.unit_test
def test_vector_embedder():
    params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
    emb = VectorEmbedder(params=params)
    emb.initialize()
    input_data = mx.nd.random.uniform(low=0, high=255, shape=(10, 100))
    output = emb(input_data)
    assert len(output.shape) == 2  # since last block was flatten
    assert output.shape[0] == 10  # since batch_size is 10
    assert output.shape[1] == 256  # since last dense layer has 256 units
@@ -0,0 +1,406 @@
import mxnet as mx
import numpy as np
import os
import pytest
from scipy import stats as sp_stats
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.architectures.head_parameters import PPOHeadParameters
from rl_coach.architectures.mxnet_components.heads.ppo_head import CategoricalDist, MultivariateNormalDist,\
    DiscretePPOHead, ClippedPPOLossDiscrete, ClippedPPOLossContinuous, PPOHead
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace


@pytest.mark.unit_test
def test_multivariate_normal_dist_shape():
    num_var = 2
    means = mx.nd.array((0, 1))
    covar = mx.nd.array(((1, 0), (0, 0.5)))
    data = mx.nd.array((0.5, 0.8))
    policy_dist = MultivariateNormalDist(num_var, means, covar)
    log_probs = policy_dist.log_prob(data)
    assert log_probs.ndim == 1
    assert log_probs.shape[0] == 1


@pytest.mark.unit_test
def test_multivariate_normal_dist_batch_shape():
    num_var = 2
    batch_size = 3
    means = mx.nd.random.uniform(shape=(batch_size, num_var))
    # create batch of covariance matrices only defined on diagonal
    std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2)
    covar = mx.nd.eye(N=num_var) * std
    data = mx.nd.random.uniform(shape=(batch_size, num_var))
    policy_dist = MultivariateNormalDist(num_var, means, covar)
    log_probs = policy_dist.log_prob(data)
    assert log_probs.ndim == 1
    assert log_probs.shape[0] == batch_size


@pytest.mark.unit_test
def test_multivariate_normal_dist_batch_time_shape():
    num_var = 2
    batch_size = 3
    time_steps = 4
    means = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var))
    # create batch (per time step) of covariance matrices only defined on diagonal
    std = mx.nd.array((1, 0.5)).broadcast_like(means).expand_dims(-2)
    covar = mx.nd.eye(N=num_var) * std
    data = mx.nd.random.uniform(shape=(batch_size, time_steps, num_var))
    policy_dist = MultivariateNormalDist(num_var, means, covar)
    log_probs = policy_dist.log_prob(data)
    assert log_probs.ndim == 2
    assert log_probs.shape[0] == batch_size
    assert log_probs.shape[1] == time_steps


@pytest.mark.unit_test
def test_multivariate_normal_dist_kl_div():
    n_classes = 2
    dist_a = MultivariateNormalDist(num_var=n_classes,
                                    mean=mx.nd.array([0.2, 0.8]).expand_dims(0),
                                    sigma=mx.nd.array([[1, 0.5], [0.5, 0.5]]).expand_dims(0))
    dist_b = MultivariateNormalDist(num_var=n_classes,
                                    mean=mx.nd.array([0.3, 0.7]).expand_dims(0),
                                    sigma=mx.nd.array([[1, 0.2], [0.2, 0.5]]).expand_dims(0))

    actual = dist_a.kl_div(dist_b).asnumpy()
    np.testing.assert_almost_equal(actual=actual, desired=0.195100128)


@pytest.mark.unit_test
def test_multivariate_normal_dist_kl_div_batch():
    n_classes = 2
    dist_a = MultivariateNormalDist(num_var=n_classes,
                                    mean=mx.nd.array([[0.2, 0.8],
                                                      [0.2, 0.8]]),
                                    sigma=mx.nd.array([[[1, 0.5], [0.5, 0.5]],
                                                       [[1, 0.5], [0.5, 0.5]]]))
    dist_b = MultivariateNormalDist(num_var=n_classes,
                                    mean=mx.nd.array([[0.3, 0.7],
                                                      [0.3, 0.7]]),
                                    sigma=mx.nd.array([[[1, 0.2], [0.2, 0.5]],
                                                       [[1, 0.2], [0.2, 0.5]]]))

    actual = dist_a.kl_div(dist_b).asnumpy()
    np.testing.assert_almost_equal(actual=actual, desired=[0.195100128, 0.195100128])
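The desired value 0.195100128 can be cross-checked against the closed-form KL divergence between two Gaussians, KL(a||b) = 0.5 * (tr(inv(Sb) Sa) + (mb - ma)^T inv(Sb) (mb - ma) - k + ln(det Sb / det Sa)). A standalone numpy check (illustrative, not part of this commit):

import numpy as np

mu_a, sigma_a = np.array([0.2, 0.8]), np.array([[1.0, 0.5], [0.5, 0.5]])
mu_b, sigma_b = np.array([0.3, 0.7]), np.array([[1.0, 0.2], [0.2, 0.5]])
sigma_b_inv = np.linalg.inv(sigma_b)
diff = mu_b - mu_a
kl = 0.5 * (np.trace(sigma_b_inv @ sigma_a)        # trace term
            + diff @ sigma_b_inv @ diff            # Mahalanobis term
            - len(mu_a)                            # dimensionality k
            + np.log(np.linalg.det(sigma_b) / np.linalg.det(sigma_a)))
assert np.isclose(kl, 0.195100128)  # the value asserted in both tests above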
@pytest.mark.unit_test
def test_categorical_dist_shape():
    num_actions = 2
    # actions taken, of shape (batch_size, time_steps)
    actions = mx.nd.array((1,))
    # action probabilities, of shape (batch_size, time_steps, num_actions)
    policy_probs = mx.nd.array((0.8, 0.2))
    policy_dist = CategoricalDist(num_actions, policy_probs)
    action_probs = policy_dist.log_prob(actions)
    assert action_probs.ndim == 1
    assert action_probs.shape[0] == 1


@pytest.mark.unit_test
def test_categorical_dist_batch_shape():
    batch_size = 3
    num_actions = 2
    # actions taken, of shape (batch_size, time_steps)
    actions = mx.nd.array((0, 1, 0))
    # action probabilities, of shape (batch_size, time_steps, num_actions)
    policy_probs = mx.nd.array(((0.8, 0.2), (0.5, 0.5), (0.5, 0.5)))
    policy_dist = CategoricalDist(num_actions, policy_probs)
    action_probs = policy_dist.log_prob(actions)
    assert action_probs.ndim == 1
    assert action_probs.shape[0] == batch_size


@pytest.mark.unit_test
def test_categorical_dist_batch_time_shape():
    batch_size = 3
    time_steps = 4
    num_actions = 2
    # actions taken, of shape (batch_size, time_steps)
    actions = mx.nd.array(((0, 1, 0, 0),
                           (1, 1, 0, 0),
                           (0, 0, 0, 0)))
    # action probabilities, of shape (batch_size, time_steps, num_actions)
    policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
    policy_dist = CategoricalDist(num_actions, policy_probs)
    action_probs = policy_dist.log_prob(actions)
    assert action_probs.ndim == 2
    assert action_probs.shape[0] == batch_size
    assert action_probs.shape[1] == time_steps


@pytest.mark.unit_test
def test_categorical_dist_batch():
    n_classes = 2
    probs = mx.nd.array(((0.8, 0.2),
                         (0.7, 0.3),
                         (0.5, 0.5)))

    dist = CategoricalDist(n_classes, probs)
    # check log_prob
    actions = mx.nd.array((0, 1, 0))
    manual_log_prob = np.array((-0.22314353, -1.20397282, -0.69314718))
    np.testing.assert_almost_equal(actual=dist.log_prob(actions).asnumpy(), desired=manual_log_prob)
    # check entropy
    sp_entropy = np.array([sp_stats.entropy(pk=(0.8, 0.2)),
                           sp_stats.entropy(pk=(0.7, 0.3)),
                           sp_stats.entropy(pk=(0.5, 0.5))])
    np.testing.assert_almost_equal(actual=dist.entropy().asnumpy(), desired=sp_entropy)


@pytest.mark.unit_test
def test_categorical_dist_kl_div():
    n_classes = 3
    dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.4, 0.2, 0.4]))
    dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.3, 0.4, 0.3]))
    dist_c = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.2, 0.6, 0.2]))
    dist_d = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([0.0, 1.0, 0.0]))
    np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_b).asnumpy(), desired=0.09151624)
    np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_c).asnumpy(), desired=0.33479536)
    np.testing.assert_almost_equal(actual=dist_c.kl_div(dist_a).asnumpy(), desired=0.38190854)
    np.testing.assert_almost_equal(actual=dist_a.kl_div(dist_d).asnumpy(), desired=np.nan)
    np.testing.assert_almost_equal(actual=dist_d.kl_div(dist_a).asnumpy(), desired=1.60943782)


@pytest.mark.unit_test
def test_categorical_dist_kl_div_batch():
    n_classes = 3
    dist_a = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.4, 0.2, 0.4],
                                                                     [0.4, 0.2, 0.4],
                                                                     [0.4, 0.2, 0.4]]))
    dist_b = CategoricalDist(n_classes=n_classes, probs=mx.nd.array([[0.3, 0.4, 0.3],
                                                                     [0.3, 0.4, 0.3],
                                                                     [0.3, 0.4, 0.3]]))
    actual = dist_a.kl_div(dist_b).asnumpy()
    np.testing.assert_almost_equal(actual=actual, desired=[0.09151624, 0.09151624, 0.09151624])
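These desired values follow directly from KL(a||b) = sum_i p_a[i] * ln(p_a[i] / p_b[i]). In particular, KL(a||d) is nan because dist_d puts zero mass where dist_a does not, while KL(d||a) = 1 * ln(1 / 0.2) = ln 5, roughly 1.60944. A standalone numpy check of the first case (illustrative, not part of this commit):

import numpy as np

p_a = np.array([0.4, 0.2, 0.4])
p_b = np.array([0.3, 0.4, 0.3])
kl_ab = np.sum(p_a * np.log(p_a / p_b))  # elementwise p_a * log(p_a / p_b), summed
assert np.isclose(kl_ab, 0.09151624)     # matches the first assertion above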
@pytest.mark.unit_test
def test_clipped_ppo_loss_continuous_batch():
    # check lower loss for policy with better probabilities:
    # i.e. higher probability on high advantage actions, low probability on low advantage actions.
    loss_fn = ClippedPPOLossContinuous(num_actions=2,
                                       clip_likelihood_ratio_using_epsilon=0.2)
    loss_fn.initialize()
    # actual actions taken, of shape (batch_size)
    actions = mx.nd.array(((0.5, -0.5), (0.2, 0.3), (0.4, 2.0)))
    # advantages from taking action, of shape (batch_size)
    advantages = mx.nd.array((2, -2, 1))
    # action probabilities, of shape (batch_size, num_actions)
    old_policy_means = mx.nd.array(((1, 0), (0, 0), (-1, 0)))
    new_policy_means_worse = mx.nd.array(((2, 0), (0, 0), (-1, 0)))
    new_policy_means_better = mx.nd.array(((0.5, 0), (0, 0), (-1, 0)))

    policy_stds = mx.nd.array(((1, 1), (1, 1), (1, 1)))
    clip_param_rescaler = mx.nd.array((1,))

    loss_worse = loss_fn(new_policy_means_worse, policy_stds,
                         actions, old_policy_means, policy_stds,
                         clip_param_rescaler, advantages)
    loss_better = loss_fn(new_policy_means_better, policy_stds,
                          actions, old_policy_means, policy_stds,
                          clip_param_rescaler, advantages)

    assert len(loss_worse) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val


@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch():
    # check lower loss for policy with better probabilities:
    # i.e. higher probability on high advantage actions, low probability on low advantage actions.
    loss_fn = ClippedPPOLossDiscrete(num_actions=2,
                                     clip_likelihood_ratio_using_epsilon=None,
                                     use_kl_regularization=True,
                                     initial_kl_coefficient=1)
    loss_fn.initialize()

    # actual actions taken, of shape (batch_size)
    actions = mx.nd.array((0, 1, 0))
    # advantages from taking action, of shape (batch_size)
    advantages = mx.nd.array((-2, 2, 1))
    # action probabilities, of shape (batch_size, num_actions)
    old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6)))

    clip_param_rescaler = mx.nd.array((1,))

    loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
    loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)

    assert len(loss_worse) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse
    assert lw_loss.ndim == 1
    assert lw_loss.shape[0] == 1
    assert len(loss_better) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better
    assert lb_loss.ndim == 1
    assert lb_loss.shape[0] == 1
    assert lw_loss > lb_loss
    assert lw_kl > lb_kl


@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch_kl_div():
    # check lower loss for policy with better probabilities:
    # i.e. higher probability on high advantage actions, low probability on low advantage actions.
    loss_fn = ClippedPPOLossDiscrete(num_actions=2,
                                     clip_likelihood_ratio_using_epsilon=None,
                                     use_kl_regularization=True,
                                     initial_kl_coefficient=0.5)
    loss_fn.initialize()

    # actual actions taken, of shape (batch_size)
    actions = mx.nd.array((0, 1, 0))
    # advantages from taking action, of shape (batch_size)
    advantages = mx.nd.array((-2, 2, 1))
    # action probabilities, of shape (batch_size, num_actions)
    old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs_worse = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs_better = mx.nd.array(((0.5, 0.5), (0.2, 0.8), (0.4, 0.6)))

    clip_param_rescaler = mx.nd.array((1,))

    loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
    loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)

    assert len(loss_worse) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    lw_loss, lw_reg, lw_kl, lw_ent, lw_lr, lw_clip_lr = loss_worse
    assert lw_kl.ndim == 1
    assert lw_kl.shape[0] == 1
    assert len(loss_better) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    lb_loss, lb_reg, lb_kl, lb_ent, lb_lr, lb_clip_lr = loss_better
    assert lb_kl.ndim == 1
    assert lb_kl.shape[0] == 1
    assert lw_kl > lb_kl
    assert lw_reg > lb_reg


@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_batch_time():
    batch_size = 3
    time_steps = 4
    num_actions = 2

    # actions taken, of shape (batch_size, time_steps)
    actions = mx.nd.array(((0, 1, 0, 0),
                           (1, 1, 0, 0),
                           (0, 0, 0, 0)))
    # advantages from taking action, of shape (batch_size, time_steps)
    advantages = mx.nd.array(((-2, 2, 1, 0),
                              (-1, 1, 0, 1),
                              (-1, 0, 1, 0)))
    # action probabilities, of shape (batch_size, num_actions)
    old_policy_probs = mx.nd.array((((0.8, 0.2), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                    ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                    ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
    new_policy_probs_worse = mx.nd.array((((0.9, 0.1), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                          ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                          ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))
    new_policy_probs_better = mx.nd.array((((0.2, 0.8), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                           ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)),
                                           ((0.5, 0.5), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5))))

    # check lower loss for policy with better probabilities:
    # i.e. higher probability on high advantage actions, low probability on low advantage actions.
    loss_fn = ClippedPPOLossDiscrete(num_actions=num_actions,
                                     clip_likelihood_ratio_using_epsilon=0.2)
    loss_fn.initialize()

    clip_param_rescaler = mx.nd.array((1,))

    loss_worse = loss_fn(new_policy_probs_worse, actions, old_policy_probs, clip_param_rescaler, advantages)
    loss_better = loss_fn(new_policy_probs_better, actions, old_policy_probs, clip_param_rescaler, advantages)

    assert len(loss_worse) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 6  # (LOSS, REGULARIZATION, KL, ENTROPY, LIKELIHOOD_RATIO, CLIPPED_LIKELIHOOD_RATIO)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val


@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_weight():
    actions = mx.nd.array((0, 1, 0))
    advantages = mx.nd.array((-2, 2, 1))
    old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))

    clip_param_rescaler = mx.nd.array((1,))
    loss_fn = ClippedPPOLossDiscrete(num_actions=2,
                                     clip_likelihood_ratio_using_epsilon=0.2)
    loss_fn.initialize()
    loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
    loss_fn_weighted = ClippedPPOLossDiscrete(num_actions=2,
                                              clip_likelihood_ratio_using_epsilon=0.2,
                                              weight=0.5)
    loss_fn_weighted.initialize()
    loss_weighted = loss_fn_weighted(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
    assert loss[0] == loss_weighted[0] * 2


@pytest.mark.unit_test
def test_clipped_ppo_loss_discrete_hybridize():
    loss_fn = ClippedPPOLossDiscrete(num_actions=2,
                                     clip_likelihood_ratio_using_epsilon=0.2)
    loss_fn.initialize()
    loss_fn.hybridize()
    actions = mx.nd.array((0, 1, 0))
    advantages = mx.nd.array((-2, 2, 1))
    old_policy_probs = mx.nd.array(((0.7, 0.3), (0.2, 0.8), (0.4, 0.6)))
    new_policy_probs = mx.nd.array(((0.9, 0.1), (0.2, 0.8), (0.4, 0.6)))
    clip_param_rescaler = mx.nd.array((1,))

    loss = loss_fn(new_policy_probs, actions, old_policy_probs, clip_param_rescaler, advantages)
    assert loss[0] == mx.nd.array((-0.142857153,))
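The expected value -0.142857 can be reproduced from the clipped surrogate objective, L = -mean(min(r * A, clip(r, 1 - eps, 1 + eps) * A)), where r is the new/old probability ratio of the taken actions. A standalone numpy check (illustrative, not part of this commit):

import numpy as np

ratios = np.array([0.9 / 0.7, 0.8 / 0.8, 0.4 / 0.4])  # new/old probs of the taken actions (0, 1, 0)
advantages = np.array([-2.0, 2.0, 1.0])
eps = 0.2
surrogate = np.minimum(ratios * advantages,
                       np.clip(ratios, 1 - eps, 1 + eps) * advantages)
assert np.isclose(-surrogate.mean(), -0.142857153)  # matches the asserted loss above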
@pytest.mark.unit_test
def test_discrete_ppo_head():
    head = DiscretePPOHead(num_actions=2)
    head.initialize()
    middleware_data = mx.nd.random.uniform(shape=(10, 100))
    probs = head(middleware_data)
    assert probs.ndim == 2  # (batch_size, num_actions)
    assert probs.shape[0] == 10  # since batch_size is 10
    assert probs.shape[1] == 2  # since num_actions is 2


@pytest.mark.unit_test
def test_ppo_head():
    agent_parameters = ClippedPPOAgentParameters()
    num_actions = 5
    action_space = DiscreteActionSpace(num_actions=num_actions)
    spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
    head = PPOHead(agent_parameters=agent_parameters,
                   spaces=spaces,
                   network_name="test_ppo_head")

    head.initialize()

    batch_size = 15
    middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
    probs = head(middleware_data)
    assert probs.ndim == 2  # (batch_size, num_actions)
    assert probs.shape[0] == batch_size
    assert probs.shape[1] == num_actions
@@ -0,0 +1,90 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.architectures.mxnet_components.heads.ppo_v_head import PPOVHead, PPOVHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace


@pytest.mark.unit_test
def test_ppo_v_head_loss_batch():
    loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1)
    total_return = mx.nd.array((5, -3, 0))
    old_policy_values = mx.nd.array((3, -1, -1))
    new_policy_values_worse = mx.nd.array((2, 0, -1))
    new_policy_values_better = mx.nd.array((4, -2, -1))

    loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return)
    loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return)

    assert len(loss_worse) == 1  # (LOSS)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 1  # (LOSS)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val
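PPOVHeadLoss follows the clipped value-loss idea from PPO implementations: new value predictions are kept within epsilon of the old ones, and the worse of the clipped and unclipped squared errors is taken. The exact reduction and scaling inside Coach are not visible in this diff, so the sketch below only illustrates the ordering the test asserts:

import numpy as np

returns = np.array([5.0, -3.0, 0.0])
v_old = np.array([3.0, -1.0, -1.0])
eps = 0.1

def clipped_value_loss(v_new):
    # clip the new predictions to within eps of the old predictions
    v_clipped = v_old + np.clip(v_new - v_old, -eps, eps)
    # pessimistic (elementwise max) of clipped and unclipped squared errors
    return np.maximum((v_new - returns) ** 2, (v_clipped - returns) ** 2).mean()

worse = clipped_value_loss(np.array([2.0, 0.0, -1.0]))
better = clipped_value_loss(np.array([4.0, -2.0, -1.0]))
assert worse > better  # the ordering asserted by test_ppo_v_head_loss_batch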
@pytest.mark.unit_test
def test_ppo_v_head_loss_batch_time():
    loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.1)
    total_return = mx.nd.array(((3, 1, 1, 0),
                                (1, 0, 0, 1),
                                (3, 0, 1, 0)))
    old_policy_values = mx.nd.array(((2, 1, 1, 0),
                                     (1, 0, 0, 1),
                                     (0, 0, 1, 0)))
    new_policy_values_worse = mx.nd.array(((2, 1, 1, 0),
                                           (1, 0, 0, 1),
                                           (2, 0, 1, 0)))
    new_policy_values_better = mx.nd.array(((3, 1, 1, 0),
                                            (1, 0, 0, 1),
                                            (2, 0, 1, 0)))

    loss_worse = loss_fn(new_policy_values_worse, old_policy_values, total_return)
    loss_better = loss_fn(new_policy_values_better, old_policy_values, total_return)

    assert len(loss_worse) == 1  # (LOSS)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 1  # (LOSS)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val


@pytest.mark.unit_test
def test_ppo_v_head_loss_weight():
    total_return = mx.nd.array((5, -3, 0))
    old_policy_values = mx.nd.array((3, -1, -1))
    new_policy_values = mx.nd.array((4, -2, -1))
    loss_fn = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=1)
    loss = loss_fn(new_policy_values, old_policy_values, total_return)
    loss_fn_weighted = PPOVHeadLoss(clip_likelihood_ratio_using_epsilon=0.2, weight=0.5)
    loss_weighted = loss_fn_weighted(new_policy_values, old_policy_values, total_return)
    assert loss[0].sum() == loss_weighted[0].sum() * 2


@pytest.mark.unit_test
def test_ppo_v_head():
    agent_parameters = ClippedPPOAgentParameters()
    action_space = DiscreteActionSpace(num_actions=5)
    spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
    value_net = PPOVHead(agent_parameters=agent_parameters,
                         spaces=spaces,
                         network_name="test_ppo_v_head")
    value_net.initialize()
    batch_size = 15
    middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
    values = value_net(middleware_data)
    assert values.ndim == 1  # (batch_size)
    assert values.shape[0] == batch_size
@@ -0,0 +1,60 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.architectures.mxnet_components.heads.q_head import QHead, QHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace


@pytest.mark.unit_test
def test_q_head_loss():
    loss_fn = QHeadLoss()
    # example with batch_size of 3, and num_actions of 2
    target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2)))
    pred_q_values_worse = mx.nd.array(((6, 5), (-1, -2), (0, 2)))
    pred_q_values_better = mx.nd.array(((4, 5), (-2, -2), (1, 2)))
    loss_worse = loss_fn(pred_q_values_worse, target_q_values)
    loss_better = loss_fn(pred_q_values_better, target_q_values)
    assert len(loss_worse) == 1  # (LOSS)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 1  # (LOSS)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val
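QHeadLoss is a squared-error regression loss between predicted and target Q-values; only the (0, 0) entry differs in the "worse" case here, while the "better" case is off by one in three entries. A standalone numpy check of the ordering (the exact reduction Coach uses, mean vs. sum, isn't shown in this diff):

import numpy as np

target = np.array([[3, 5], [-1, -2], [0, 2]], dtype=np.float32)
worse = np.array([[6, 5], [-1, -2], [0, 2]], dtype=np.float32)   # one entry off by 3
better = np.array([[4, 5], [-2, -2], [1, 2]], dtype=np.float32)  # three entries off by 1
assert np.mean((worse - target) ** 2) > np.mean((better - target) ** 2)  # 1.5 > 0.5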
@pytest.mark.unit_test
def test_q_head_loss_weight():
    target_q_values = mx.nd.array(((3, 5), (-1, -2), (0, 2)))
    pred_q_values = mx.nd.array(((4, 5), (-2, -2), (1, 2)))
    loss_fn = QHeadLoss()
    loss = loss_fn(pred_q_values, target_q_values)
    loss_fn_weighted = QHeadLoss(weight=0.5)
    loss_weighted = loss_fn_weighted(pred_q_values, target_q_values)
    assert loss[0] == loss_weighted[0] * 2


@pytest.mark.unit_test
def test_q_head():
    agent_parameters = ClippedPPOAgentParameters()
    num_actions = 5
    action_space = DiscreteActionSpace(num_actions=num_actions)
    spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
    value_net = QHead(agent_parameters=agent_parameters,
                      spaces=spaces,
                      network_name="test_q_head")
    value_net.initialize()
    batch_size = 15
    middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
    values = value_net(middleware_data)
    assert values.ndim == 2  # (batch_size, num_actions)
    assert values.shape[0] == batch_size
    assert values.shape[1] == num_actions
@@ -0,0 +1,57 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.architectures.mxnet_components.heads.v_head import VHead, VHeadLoss
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAlgorithmParameters, ClippedPPOAgentParameters
from rl_coach.spaces import SpacesDefinition, DiscreteActionSpace


@pytest.mark.unit_test
def test_v_head_loss():
    loss_fn = VHeadLoss()
    target_values = mx.nd.array((3, -1, 0))
    pred_values_worse = mx.nd.array((0, 0, 1))
    pred_values_better = mx.nd.array((2, -1, 0))
    loss_worse = loss_fn(pred_values_worse, target_values)
    loss_better = loss_fn(pred_values_better, target_values)
    assert len(loss_worse) == 1  # (LOSS)
    loss_worse_val = loss_worse[0]
    assert loss_worse_val.ndim == 1
    assert loss_worse_val.shape[0] == 1
    assert len(loss_better) == 1  # (LOSS)
    loss_better_val = loss_better[0]
    assert loss_better_val.ndim == 1
    assert loss_better_val.shape[0] == 1
    assert loss_worse_val > loss_better_val


@pytest.mark.unit_test
def test_v_head_loss_weight():
    target_values = mx.nd.array((3, -1, 0))
    pred_values = mx.nd.array((0, 0, 1))
    loss_fn = VHeadLoss()
    loss = loss_fn(pred_values, target_values)
    loss_fn_weighted = VHeadLoss(weight=0.5)
    loss_weighted = loss_fn_weighted(pred_values, target_values)
    assert loss[0] == loss_weighted[0] * 2
@pytest.mark.unit_test
def test_v_head():
    agent_parameters = ClippedPPOAgentParameters()
    action_space = DiscreteActionSpace(num_actions=5)
    spaces = SpacesDefinition(state=None, goal=None, action=action_space, reward=None)
    value_net = VHead(agent_parameters=agent_parameters,
                      spaces=spaces,
                      network_name="test_v_head")
    value_net.initialize()
    batch_size = 15
    middleware_data = mx.nd.random.uniform(shape=(batch_size, 100))
    values = value_net(middleware_data)
    assert values.ndim == 1  # (batch_size)
    assert values.shape[0] == batch_size
@@ -0,0 +1,22 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import FCMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.fc_middleware import FCMiddleware


@pytest.mark.unit_test
def test_fc_middleware():
    params = FCMiddlewareParameters(scheme=MiddlewareScheme.Medium)
    mid = FCMiddleware(params=params)
    mid.initialize()
    embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 100))
    output = mid(embedded_data)
    assert output.ndim == 2  # since last block was flatten
    assert output.shape[0] == 10  # since batch_size is 10
    assert output.shape[1] == 512  # since last layer of middleware (Medium scheme) has 512 units
@@ -0,0 +1,25 @@
import mxnet as mx
import os
import pytest
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))


from rl_coach.base_parameters import MiddlewareScheme
from rl_coach.architectures.middleware_parameters import LSTMMiddlewareParameters
from rl_coach.architectures.mxnet_components.middlewares.lstm_middleware import LSTMMiddleware


@pytest.mark.unit_test
def test_lstm_middleware():
    params = LSTMMiddlewareParameters(number_of_lstm_cells=25, scheme=MiddlewareScheme.Medium)
    mid = LSTMMiddleware(params=params)
    mid.initialize()
    # input is NTC: (batch_size, time_steps, channels)
    embedded_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 15, 20))
    # the middleware transposes NTC -> TNC
    output = mid(embedded_data)
    assert output.ndim == 3  # (time_steps, batch_size, num_cells)
    assert output.shape[0] == 15  # since t is 15
    assert output.shape[1] == 10  # since batch_size is 10
    assert output.shape[2] == 25  # since number_of_lstm_cells is 25
rl_coach/tests/architectures/mxnet_components/test_utils.py (new file, 144 lines)
@@ -0,0 +1,144 @@
import pytest

import mxnet as mx
from mxnet import nd
import numpy as np

from rl_coach.architectures.mxnet_components.utils import *


@pytest.mark.unit_test
def test_to_mx_ndarray():
    # scalar
    assert to_mx_ndarray(1.2) == nd.array([1.2])
    # list of one scalar
    assert to_mx_ndarray([1.2]) == [nd.array([1.2])]
    # list of multiple scalars
    assert to_mx_ndarray([1.2, 3.4]) == [nd.array([1.2]), nd.array([3.4])]
    # list of lists of scalars
    assert to_mx_ndarray([[1.2], [3.4]]) == [[nd.array([1.2])], [nd.array([3.4])]]
    # numpy
    assert np.array_equal(to_mx_ndarray(np.array([[1.2], [3.4]])).asnumpy(), nd.array([[1.2], [3.4]]).asnumpy())
    # tuple
    assert to_mx_ndarray(((1.2,), (3.4,))) == ((nd.array([1.2]),), (nd.array([3.4]),))


@pytest.mark.unit_test
def test_asnumpy_or_asscalar():
    # scalar float32
    assert asnumpy_or_asscalar(nd.array([1.2])) == np.float32(1.2)
    # scalar int32
    assert asnumpy_or_asscalar(nd.array([2], dtype=np.int32)) == np.int32(2)
    # list of one scalar
    assert asnumpy_or_asscalar([nd.array([1.2])]) == [np.float32(1.2)]
    # list of multiple scalars
    assert asnumpy_or_asscalar([nd.array([1.2]), nd.array([3.4])]) == [np.float32([1.2]), np.float32([3.4])]
    # list of lists of scalars
    assert asnumpy_or_asscalar([[nd.array([1.2])], [nd.array([3.4])]]) == [[np.float32([1.2])], [np.float32([3.4])]]
    # tensor
    assert np.array_equal(asnumpy_or_asscalar(nd.array([[1.2], [3.4]])), np.array([[1.2], [3.4]], dtype=np.float32))
    # tuple
    assert (asnumpy_or_asscalar(((nd.array([1.2]),), (nd.array([3.4]),))) ==
            ((np.array([1.2], dtype=np.float32),), (np.array([3.4], dtype=np.float32),)))


@pytest.mark.unit_test
def test_global_norm():
    data = list()
    for i in range(1, 6):
        data.append(np.ones((i * 10, i * 10)) * i)
    gnorm = np.asscalar(np.sqrt(sum([np.sum(np.square(d)) for d in data])))
    assert np.isclose(gnorm, global_norm([nd.array(d) for d in data]).asscalar())


@pytest.mark.unit_test
def test_split_outputs_per_head():
    class TestHead:
        def __init__(self, num_outputs):
            self.num_outputs = num_outputs

    assert split_outputs_per_head((1, 2, 3, 4), [TestHead(2), TestHead(1), TestHead(1)]) == [[1, 2], [3], [4]]


class DummySchema:
    def __init__(self, num_head_outputs, num_agent_inputs, num_targets):
        self.head_outputs = ['head_output_{}'.format(i) for i in range(num_head_outputs)]
        self.agent_inputs = ['agent_input_{}'.format(i) for i in range(num_agent_inputs)]
        self.targets = ['target_{}'.format(i) for i in range(num_targets)]


class DummyLoss:
    def __init__(self, num_head_outputs, num_agent_inputs, num_targets):
        self.input_schema = DummySchema(num_head_outputs, num_agent_inputs, num_targets)


@pytest.mark.unit_test
def test_split_targets_per_loss():
    assert split_targets_per_loss([1, 2, 3, 4],
                                  [DummyLoss(10, 100, 2), DummyLoss(20, 200, 1), DummyLoss(30, 300, 1)]) == \
        [[1, 2], [3], [4]]


@pytest.mark.unit_test
def test_get_loss_agent_inputs():
    input_dict = {'output_0_0': [1, 2], 'output_0_1': [3, 4], 'output_1_0': [5]}
    assert get_loss_agent_inputs(input_dict, 0, DummyLoss(10, 2, 100)) == [[1, 2], [3, 4]]
    assert get_loss_agent_inputs(input_dict, 1, DummyLoss(20, 1, 200)) == [[5]]


@pytest.mark.unit_test
def test_align_loss_args():
    class TestLossFwd(DummyLoss):
        def __init__(self, num_targets, num_agent_inputs, num_head_outputs):
            super(TestLossFwd, self).__init__(num_targets, num_agent_inputs, num_head_outputs)

        def loss_forward(self, F, head_output_2, head_output_1, agent_input_2, target_0, agent_input_1, param1, param2):
            pass

    assert align_loss_args([1, 2, 3], [4, 5, 6, 7], [8, 9], TestLossFwd(3, 4, 2)) == [3, 2, 6, 8, 5]


@pytest.mark.unit_test
def test_to_tuple():
    assert to_tuple(123) == (123,)
    assert to_tuple((1, 2, 3)) == (1, 2, 3)
    assert to_tuple([1, 2, 3]) == (1, 2, 3)


@pytest.mark.unit_test
def test_to_list():
    assert to_list(123) == [123]
    assert to_list((1, 2, 3)) == [1, 2, 3]
    assert to_list([1, 2, 3]) == [1, 2, 3]


@pytest.mark.unit_test
def test_loss_output_dict():
    assert loss_output_dict([1, 2, 3], ['loss', 'loss', 'reg']) == {'loss': [1, 2], 'reg': [3]}


@pytest.mark.unit_test
def test_clip_grad():
    a = np.array([1, 2, -3])
    b = np.array([4, 5, -6])
    clip = 2
    gscale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(a)) + np.sum(np.square(b))))
    for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByGlobalNorm, clip_val=clip),
                        [a, b]):
        assert np.allclose(lhs.asnumpy(), rhs * gscale)
    for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByValue, clip_val=clip),
                        [a, b]):
        assert np.allclose(lhs.asnumpy(), np.clip(rhs, -clip, clip))
    for lhs, rhs in zip(clip_grad([nd.array(a), nd.array(b)], GradientClippingMethod.ClipByNorm, clip_val=clip),
                        [a, b]):
        scale = np.minimum(1.0, clip / np.sqrt(np.sum(np.square(rhs))))
        assert np.allclose(lhs.asnumpy(), rhs * scale)


@pytest.mark.unit_test
def test_hybrid_clip():
    x = mx.nd.array((0.5, 1.5, 2.5))
    a = mx.nd.array((1,))
    b = mx.nd.array((2,))
    clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
    assert (np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2))).all()