Added ONNX compatible broadcast_like function (#152)
- Also simplified the hybrid_clip implementation.
Committed by Gal Leibovich
Parent: 8df425b6e1
Commit: 19a68812f6
@@ -11,7 +11,7 @@ from rl_coach.utils import eps
 from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
     NormalizedRSSInitializer
 from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
-from rl_coach.architectures.mxnet_components.utils import hybrid_clip
+from rl_coach.architectures.mxnet_components.utils import hybrid_clip, broadcast_like


 LOSS_OUT_TYPE_KL = 'kl_divergence'
@@ -146,7 +146,7 @@ class MultivariateNormalDist:
         sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
         term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
         # sum of diagonal for batch of matrices
-        term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
+        term1 = (broadcast_like(self.F, self.F.eye(self.num_var), term1a) * term1a).sum(axis=-1).sum(axis=-1)
         mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
         mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
         term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
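The changed line computes the trace of each matrix in the batch by masking with an identity matrix broadcast to the batch shape; the only difference is that the broadcast now goes through the new ONNX-friendly helper. A minimal NDArray sketch of that step (not part of the commit, shapes chosen for illustration, and assuming rl_coach is installed so the helper can be imported):

import mxnet as mx
from rl_coach.architectures.mxnet_components.utils import broadcast_like

num_var = 2
# a batch of two 2x2 matrices, standing in for term1a of shape (batch, num_var, num_var)
term1a = mx.nd.array([[[1., 2.], [3., 4.]],
                      [[5., 6.], [7., 8.]]])
eye = mx.nd.eye(num_var)
# broadcasting the identity over the batch, multiplying element-wise and summing the
# last two axes keeps only the diagonal entries, i.e. the per-sample trace
term1 = (broadcast_like(mx.nd, eye, term1a) * term1a).sum(axis=-1).sum(axis=-1)
print(term1.asnumpy())  # [ 5. 13.]  (1+4 and 5+8)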
@@ -155,7 +155,6 @@ class MultivariateNormalDist:
         return 0.5 * (term1 + term2 - self.num_var + term3)


-
 class CategoricalDist:
     def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
         """
@@ -284,7 +283,7 @@ class ContinuousPPOHead(nn.HybridBlock):
         of shape (batch_size, time_step, action_mean).
         """
         policy_means = self.dense(x)
-        policy_std = log_std.exp().expand_dims(0).broadcast_like(policy_means)
+        policy_std = broadcast_like(F, log_std.exp().expand_dims(0), policy_means)
         return policy_means, policy_std

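Same substitution in the PPO head: the per-action std produced from log_std has to be replicated across the batch of policy means, and the tensor method broadcast_like is swapped for the standalone helper so the hybridized graph stays exportable. A rough sketch of the shapes involved (illustrative only, assuming 3 actions and a batch of 4):

import mxnet as mx
from rl_coach.architectures.mxnet_components.utils import broadcast_like

log_std = mx.nd.zeros((3,))                         # one log-std per action
policy_means = mx.nd.random.normal(shape=(4, 3))    # (batch_size, num_actions)
# expand to (1, 3), then broadcast to the shape of the means
policy_std = broadcast_like(mx.nd, log_std.exp().expand_dims(0), policy_means)
print(policy_std.shape)  # (4, 3): every row equals exp(log_std)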
@@ -253,13 +253,23 @@ def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upp
     :param clip_upper: upper bound used for clipping, should be of shape (1,)
     :return: clipped data
     """
-    x_clip_lower = clip_lower.broadcast_like(x)
-    x_clip_upper = clip_upper.broadcast_like(x)
-    x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
-    x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
+    x_clip_lower = broadcast_like(F, clip_lower, x)
+    x_clip_upper = broadcast_like(F, clip_upper, x)
+    x_clipped = F.minimum(F.maximum(x, x_clip_lower), x_clip_upper)
     return x_clipped


+def broadcast_like(F: ModuleType, x: nd_sym_type, y: nd_sym_type) -> nd_sym_type:
+    """
+    Implementation of broadcast_like using broadcast_add and broadcast_mul because ONNX doesn't support broadcast_like.
+    :param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
+    :param x: input to be broadcast
+    :param y: tensor to broadcast x like
+    :return: broadcast x
+    """
+    return F.broadcast_mul(x, (y * 0) + 1)
+
+
 def get_mxnet_activation_name(activation_name: str):
     """
     Convert coach activation name to mxnet specific activation name
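Why the one-liner works (an illustrative sketch, not taken from the repository): multiplying x by (y * 0) + 1, an all-ones tensor with y's shape, pushes x through the ordinary broadcasting of broadcast_mul, which the ONNX exporter handles, whereas the dedicated broadcast_like operator does not (per the docstring above). The same helper feeds the simplified hybrid_clip, which now clips with element-wise maximum/minimum instead of stacking and reducing:

import mxnet as mx

def broadcast_like(F, x, y):
    # same one-liner as the function added above
    return F.broadcast_mul(x, (y * 0) + 1)

x = mx.nd.array([[1., 2.]])               # shape (1, 2)
y = mx.nd.zeros((3, 1, 2))                # tensor whose shape we want
print(broadcast_like(mx.nd, x, y).shape)  # (3, 1, 2), same result as x.broadcast_like(y)

# hybrid_clip after the change: bounds are broadcast to the data shape, then clipped
data = mx.nd.array([0.5, 1.5, 3.0])
lo, hi = mx.nd.array((1.,)), mx.nd.array((2.,))
clipped = mx.nd.minimum(mx.nd.maximum(data, broadcast_like(mx.nd, lo, data)),
                        broadcast_like(mx.nd, hi, data))
print(clipped.asnumpy())                  # [1.  1.5 2. ]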
@@ -141,7 +141,14 @@ def test_hybrid_clip():
     a = mx.nd.array((1,))
     b = mx.nd.array((2,))
     clipped = hybrid_clip(F=mx.nd, x=x, clip_lower=a, clip_upper=b)
-    assert (np.isclose(a= clipped.asnumpy(), b=(1, 1.5, 2))).all()
+    assert (np.isclose(a=clipped.asnumpy(), b=(1, 1.5, 2))).all()


+@pytest.mark.unit_test
+def test_broadcast_like():
+    x = nd.ones((1, 2)) * 10
+    y = nd.ones((100, 100, 2)) * 20
+    assert mx.test_utils.almost_equal(x.broadcast_like(y).asnumpy(), broadcast_like(nd, x, y).asnumpy())
+
+
 @pytest.mark.unit_test
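One point the test does not exercise: because the helper takes the backend module F explicitly, the same code path also works after a block has been hybridized, when F is mxnet.sym rather than mxnet.nd. A small symbolic sketch (assumed shapes, not part of the test file):

import mxnet as mx
from rl_coach.architectures.mxnet_components.utils import broadcast_like

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
out = broadcast_like(mx.sym, x, y)   # builds broadcast_mul(x, (y * 0) + 1) symbolically
res = out.eval(ctx=mx.cpu(),
               x=mx.nd.ones((1, 2)) * 10,
               y=mx.nd.ones((4, 2)))
print(res[0].shape)  # (4, 2)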