Skip to content

Commit b3ec873

Browse files
fix(atenlib): merge aten_clamp_min/max tensor and scalar version; repeat (#262)
- Merge `aten_clamp_min`, `aten_clamp_max`'s scalar and tensor versions by adding logic on the input rank. - Fix repeat on the empty `repeats` case. Co-authored-by: Jay Zhang <[email protected]>
1 parent c4a655e commit b3ec873

File tree

3 files changed

+48
-54
lines changed

3 files changed

+48
-54
lines changed

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -765,35 +765,42 @@ def aten_clamp(self: TensorType, min_=None, max_=None) -> TensorType:
765765
return clamped
766766

767767

@torch_op("aten::clamp_max")
def aten_clamp_max(self, max_):
    # clamp_max(Tensor self, Tensor max) -> Tensor

    # Single implementation covering both the scalar and tensor overloads:
    # the runtime rank of max_ selects the appropriate lowering.
    input_numel = op.Size(self)
    bound_shape = op.Shape(max_)
    bound_rank = op.Size(bound_shape)
    if input_numel == 0:
        # Empty input: expand to the bound's shape (result stays empty).
        # NOTE(review): presumably mirrors PyTorch broadcasting for empty
        # tensors — confirm against torch.clamp_max semantics.
        result = op.Expand(self, bound_shape)
    else:
        if bound_rank == 0:
            # Rank-0 bound behaves like a scalar: use Clip's max input.
            scalar_bound = op.CastLike(max_, self)
            result = op.Clip(self, None, scalar_bound)
        else:
            # Tensor bound: elementwise minimum clamps from above.
            result = op.Min(self, max_)
    return result

@torch_op("aten::clamp_min")
def aten_clamp_min(self, min_):
    # clamp_min(Tensor self, Tensor min) -> Tensor

    # Single implementation covering both the scalar and tensor overloads:
    # the runtime rank of min_ selects the appropriate lowering.
    input_numel = op.Size(self)
    bound_shape = op.Shape(min_)
    bound_rank = op.Size(bound_shape)
    if input_numel == 0:
        # Empty input: expand to the bound's shape (result stays empty).
        # NOTE(review): presumably mirrors PyTorch broadcasting for empty
        # tensors — confirm against torch.clamp_min semantics.
        result = op.Expand(self, bound_shape)
    else:
        if bound_rank == 0:
            # Rank-0 bound behaves like a scalar: use Clip's min input.
            scalar_bound = op.CastLike(min_, self)
            result = op.Clip(self, scalar_bound, None)
        else:
            # Tensor bound: elementwise maximum clamps from below.
            result = op.Max(self, min_)
    return result

798805

799806
def aten_clone(self: TensorType, memory_format: Optional[str] = None) -> TensorType:
@@ -3976,16 +3983,18 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
39763983
def aten_repeat(self, repeats: INT64):
    # repeat(Tensor self, SymInt[] repeats) -> Tensor

    if op.Size(repeats) == 0:
        # Empty repeats: return the input unchanged.
        result = self
    else:
        # TODO(justinchuby): Make ones_like a function when onnxscript supports it
        # Build a ones tensor shaped like `repeats` so that Expand first
        # broadcasts self up to the target rank, and Tile then repeats it.
        one = op.Constant(value_int=1)
        repeats_shape = op.Shape(repeats)
        ones = op.Expand(one, repeats_shape)
        expanded = op.Expand(self, ones)  # type: ignore[arg-type]
        result = op.Tile(expanded, repeats)
    return result
39893998

39903999

39914000
def aten_repeat_interleave(

onnxscript/test/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 5 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def xfail(
7979
*,
8080
reason: str,
8181
dtypes: Optional[Collection[torch.dtype]] = None,
82-
):
82+
) -> DecorateMeta:
8383
"""Expects an OpInfo test to fail.
8484
8585
Args:
@@ -104,7 +104,7 @@ def skip(
104104
reason: str,
105105
dtypes: Optional[Collection[torch.dtype]] = None,
106106
matcher: Optional[Callable[[Any], Any]] = None,
107-
):
107+
) -> DecorateMeta:
108108
"""Skips an OpInfo test.
109109
110110
Args:
@@ -171,8 +171,8 @@ def wrapped(fn):
171171
"atanh": core_ops.aten_atanh,
172172
"bmm": core_ops.aten_bmm,
173173
"ceil": core_ops.aten_ceil,
174-
"clamp_max": core_ops.aten_clamp_max_tensor,
175-
"clamp_min": core_ops.aten_clamp_min_tensor,
174+
"clamp_max": core_ops.aten_clamp_max,
175+
"clamp_min": core_ops.aten_clamp_min,
176176
"clamp": core_ops.aten_clamp,
177177
"cos": core_ops.aten_cos,
178178
"cosh": core_ops.aten_cosh,
@@ -355,23 +355,7 @@ def wrapped(fn):
355355
)
356356

357357

# Subtests to skip; currently none.
SKIP_SUBTESTS: tuple[DecorateMeta, ...] = ()
OP_WITH_SKIPPED_SUBTESTS = frozenset(m.op_name for m in SKIP_SUBTESTS)
376360

377361
# END OF SECTION TO MODIFY #####################################################

requirements-onnx.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Requirements for testing with the release version of onnx and onnxruntime
22

33
# TODO(#249): Fix tests for onnx 1.13
4+
numpy==1.21.5
45
onnx==1.12
56
onnxruntime

0 commit comments

Comments
 (0)