mirror of
https://github.com/gryf/coach.git
synced 2025-12-18 11:40:18 +01:00
Added ONNX compatible broadcast_like function (#152)
- Also simplified the hybrid_clip implementation.
This commit is contained in:
committed by
Gal Leibovich
parent
8df425b6e1
commit
19a68812f6
@@ -253,13 +253,23 @@ def hybrid_clip(F: ModuleType, x: nd_sym_type, clip_lower: nd_sym_type, clip_upp
|
||||
:param clip_upper: upper bound used for clipping, should be of shape (1,)
|
||||
:return: clipped data
|
||||
"""
|
||||
x_clip_lower = clip_lower.broadcast_like(x)
|
||||
x_clip_upper = clip_upper.broadcast_like(x)
|
||||
x_clipped = F.stack(x, x_clip_lower, axis=0).max(axis=0)
|
||||
x_clipped = F.stack(x_clipped, x_clip_upper, axis=0).min(axis=0)
|
||||
x_clip_lower = broadcast_like(F, clip_lower, x)
|
||||
x_clip_upper = broadcast_like(F, clip_upper, x)
|
||||
x_clipped = F.minimum(F.maximum(x, x_clip_lower), x_clip_upper)
|
||||
return x_clipped
|
||||
|
||||
|
||||
def broadcast_like(F: ModuleType, x: nd_sym_type, y: nd_sym_type) -> nd_sym_type:
|
||||
"""
|
||||
Implementation of broadcast_like using broadcast_add and broadcast_mul because ONNX doesn't support broadcast_like.
|
||||
:param F: backend api, either `mxnet.nd` or `mxnet.sym` (if block has been hybridized).
|
||||
:param x: input to be broadcast
|
||||
:param y: tensor to broadcast x like
|
||||
:return: broadcast x
|
||||
"""
|
||||
return F.broadcast_mul(x, (y * 0) + 1)
|
||||
|
||||
|
||||
def get_mxnet_activation_name(activation_name: str):
|
||||
"""
|
||||
Convert coach activation name to mxnet specific activation name
|
||||
|
||||
Reference in New Issue
Block a user