From 19a68812f6a1a7f4bbab906a880acc27b047359c Mon Sep 17 00:00:00 2001
From: Sina Afrooze
Date: Sun, 25 Nov 2018 01:23:18 -0800
Subject: [PATCH] Added ONNX compatible broadcast_like function (#152)

- Also simplified the hybrid_clip implementation.
---
 .../mxnet_components/heads/ppo_head.py      |  7 +++----
 .../architectures/mxnet_components/utils.py | 18 ++++++++++++++----
 .../mxnet_components/test_utils.py          |  9 ++++++++-
 3 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/rl_coach/architectures/mxnet_components/heads/ppo_head.py b/rl_coach/architectures/mxnet_components/heads/ppo_head.py
index 269aec6..18fb40f 100644
--- a/rl_coach/architectures/mxnet_components/heads/ppo_head.py
+++ b/rl_coach/architectures/mxnet_components/heads/ppo_head.py
@@ -11,7 +11,7 @@ from rl_coach.utils import eps
 from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
     NormalizedRSSInitializer
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
-from rl_coach.architectures.mxnet_components.utils import hybrid_clip
+from rl_coach.architectures.mxnet_components.utils import hybrid_clip, broadcast_like
 
 LOSS_OUT_TYPE_KL = 'kl_divergence'
 
@@ -146,7 +146,7 @@ class MultivariateNormalDist:
         sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
         term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
         # sum of diagonal for batch of matrices
-        term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
+        term1 = (broadcast_like(self.F, self.F.eye(self.num_var), term1a) * term1a).sum(axis=-1).sum(axis=-1)
         mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
         mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
         term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
@@ -155,7 +155,6 @@ class MultivariateNormalDist:
         return 0.5 * (term1 + term2 - self.num_var + term3)
 
 
-
 class CategoricalDist:
     def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
         """
@@ -284,7 +283,7 @@ class ContinuousPPOHead(nn.HybridBlock):
             of shape (batch_size, time_step, action_mean).
         """
         policy_means = self.dense(x)
-        policy_std = log_std.exp().expand_dims(0).broadcast_like(policy_means)
+        policy_std = broadcast_like(F, log_std.exp().expand_dims(0), policy_means)
         return policy_means, policy_std
 
 
diff --git a/rl_coach/architectures/mxnet_components/utils.py b/rl_coach/architectures/mxnet_components/utils.py
index bf243dd..6388b82 100644
--- a/rl_coach/architectures/mxnet_components/utils.py
+++ b/rl_coach/architectures/mxnet_components/utils.py
@@ -253,13 +253,23 @@ def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upp
     :param clip_upper: upper bound used for clipping, should be of shape (1,)
     :return: clipped data
     """
-    x_clip_lower = clip_lower.broadcast_like(x)
-    x_clip_upper = clip_upper.broadcast_like(x)
-    x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
-    x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
+    x_clip_lower = broadcast_like(F, clip_lower, x)
+    x_clip_upper = broadcast_like(F, clip_upper, x)
+    x_clipped = F.minimum(F.maximum(x, x_clip_lower), x_clip_upper)
     return x_clipped
 
 
+def broadcast_like(F: ModuleType, x: nd_sym_type, y: nd_sym_type) -> nd_sym_type:
+    """
+    Implementation of broadcast_like using broadcast_add and broadcast_mul because ONNX doesn't support broadcast_like.
+    :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
+    :param x: input to be broadcast
+    :param y: tensor to broadcast x like
+    :return: broadcast x
+    """
+    return F.broadcast_mul(x, (y * 0) + 1)
+
+
 def get_mxnet_activation_name(activation_name: str):
     """
     Convert coach activation name to mxnet specific activation name
diff --git a/rl_coach/tests/architectures/mxnet_components/test_utils.py b/rl_coach/tests/architectures/mxnet_components/test_utils.py
index 2765998..b88d24e 100644
--- a/rl_coach/tests/architectures/mxnet_components/test_utils.py
+++ b/rl_coach/tests/architectures/mxnet_components/test_utils.py
@@ -141,7 +141,14 @@ def test_hybrid_clip():
     a = mx.nd.array((1,))
     b = mx.nd.array((2,))
     clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
-    assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all()
+    assert (np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2))).all()
+
+
+@pytest.mark.unit_test
+def test_broadcast_like():
+    x = nd.ones((1, 2)) * 10
+    y = nd.ones((100, 100, 2)) * 20
+    assert mx.test_utils.almost_equal(x.broadcast_like(y).asnumpy(), broadcast_like(nd, x, y).asnumpy())
 
 
 @pytest.mark.unit_test
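
For reviewers who want to sanity-check the new helper outside the test suite, here is a minimal standalone sketch of why the patch's formulation reproduces broadcast_like; the shapes and variable names below are illustrative, not taken from the patch:

    import mxnet as mx
    from mxnet import nd

    x = nd.array([[1.0, 2.0]])    # shape (1, 2)
    y = nd.zeros((3, 2))          # illustrative target shape

    # The ONNX-friendly trick from the patch: (y * 0) + 1 builds a ones-tensor
    # with y's shape, and broadcast_mul then stretches x across it.
    broadcast_x = nd.broadcast_mul(x, (y * 0) + 1)

    # Agrees with the native operator that the ONNX exporter cannot handle.
    assert mx.test_utils.almost_equal(broadcast_x.asnumpy(),
                                      x.broadcast_like(y).asnumpy())

    # The simplified hybrid_clip follows the same pattern: broadcast the (1,)
    # bounds to x's shape, then clip with elementwise minimum/maximum.
    lo, hi = nd.array((1,)), nd.array((2,))
    vals = nd.array((0.0, 1.5, 10.0))
    clipped = nd.minimum(nd.maximum(vals, nd.broadcast_mul(lo, (vals * 0) + 1)),
                         nd.broadcast_mul(hi, (vals * 0) + 1))
    print(clipped.asnumpy())      # [1.  1.5 2. ]

The multiply-by-ones form only relies on elementwise arithmetic and broadcast_mul, operators the MXNet-to-ONNX exporter handles, which is the point of replacing broadcast_like throughout.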