mirror of
https://github.com/gryf/coach.git
synced 2026-04-19 22:23:32 +02:00
Added ONNX-compatible broadcast_like function (#152)
- Also simplified the hybrid_clip implementation.
This commit is contained in:
committed by
Gal Leibovich
parent
8df425b6e1
commit
19a68812f6
@@ -11,7 +11,7 @@ from rl_coach.utils import eps
|
||||
from rl_coach.architectures.mxnet_components.heads.head import Head, HeadLoss, LossInputSchema,\
|
||||
NormalizedRSSInitializer
|
||||
from rl_coach.architectures.mxnet_components.heads.head import LOSS_OUT_TYPE_LOSS, LOSS_OUT_TYPE_REGULARIZATION
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip
|
||||
from rl_coach.architectures.mxnet_components.utils import hybrid_clip, broadcast_like
|
||||
|
||||
|
||||
LOSS_OUT_TYPE_KL = 'kl_divergence'
|
||||
@@ -146,7 +146,7 @@ class MultivariateNormalDist:
|
||||
sigma_b_inv = self.F.linalg.potri(self.F.linalg.potrf(alt_dist.sigma))
|
||||
term1a = mx.nd.batch_dot(sigma_b_inv, self.sigma)
|
||||
# sum of the diagonal (i.e. the trace) of each matrix in the batch
|
||||
term1 = (self.F.eye(self.num_var).broadcast_like(term1a) * term1a).sum(axis=-1).sum(axis=-1)
|
||||
term1 = (broadcast_like(self.F, self.F.eye(self.num_var), term1a) * term1a).sum(axis=-1).sum(axis=-1)
|
||||
mean_diff = (alt_dist.mean - self.mean).expand_dims(-1)
|
||||
mean_diff_t = (alt_dist.mean - self.mean).expand_dims(-2)
|
||||
term2 = self.F.batch_dot(self.F.batch_dot(mean_diff_t, sigma_b_inv), mean_diff).reshape_like(term1)
|
||||
@@ -155,7 +155,6 @@ class MultivariateNormalDist:
|
||||
return 0.5 * (term1 + term2 - self.num_var + term3)
|
||||
|
||||
|
||||
|
||||
class CategoricalDist:
|
||||
def __init__(self, n_classes: int, probs: nd_sym_type, F: ModuleType=mx.nd) -> None:
|
||||
"""
|
||||
@@ -284,7 +283,7 @@ class ContinuousPPOHead(nn.HybridBlock):
|
||||
of shape (batch_size, time_step, action_mean).
|
||||
"""
|
||||
policy_means = self.dense(x)
|
||||
policy_std = log_std.exp().expand_dims(0).broadcast_like(policy_means)
|
||||
policy_std = broadcast_like(F, log_std.exp().expand_dims(0), policy_means)
|
||||
return policy_means, policy_std
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user