Skip to content

Commit d910fb6

Browse files
committed
[torchlib] Implement clamp* scalar overloads (#2066)
Fix issues reported in #2050 (comment)
1 parent 0f8c07d commit d910fb6

File tree

2 files changed

+53
-16
lines changed

2 files changed

+53
-16
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 50 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -1662,13 +1662,32 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
16621662
return op.SplitToSequence(self, list_split, axis=dim)
16631663

16641664

1665-
@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
1666-
def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
1667-
"""clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
1668-
clamped = self
1665+
@torch_op("aten::clamp", trace_only=True)
1666+
def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal:
1667+
"""clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"""
16691668

16701669
if min is None and max is None:
1671-
return clamped
1670+
return op.Identity(self)
1671+
1672+
if min is not None:
1673+
min = op.CastLike(min, self)
1674+
1675+
if max is not None:
1676+
max = op.CastLike(max, self)
1677+
1678+
return op.Clip(self, min, max)
1679+
1680+
1681+
@torch_op("aten::clamp.Tensor", trace_only=True)
1682+
def aten_clamp_tensor(
1683+
self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None
1684+
) -> TReal:
1685+
"""clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
1686+
1687+
if min is None and max is None:
1688+
return op.Identity(self)
1689+
1690+
clamped = self
16721691

16731692
# If min is greater than max torch.clamp(..., min, max)
16741693
# sets all elements in input to the value of max.
@@ -1684,11 +1703,20 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] =
16841703
return clamped
16851704

16861705

1687-
@torch_op(("aten::clamp_max", "aten::clamp_max.Tensor"), trace_only=True)
1688-
def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
1689-
"""clamp_max(Tensor self, Tensor max) -> Tensor"""
1706+
@torch_op("aten::clamp_max", trace_only=True)
1707+
def aten_clamp_max(self: TReal, max_: float) -> TReal:
1708+
"""clamp_max(Tensor self, Scalar max) -> Tensor"""
1709+
1710+
# This implementation does not intend to handle when self is an empty tensor
1711+
max_ = op.CastLike(max_, self)
1712+
return op.Clip(self, None, max_)
16901713

1691-
# This implementation does not intent to handle when self is an empty tensor
1714+
1715+
@torch_op("aten::clamp_max.Tensor", trace_only=True)
1716+
def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
1717+
"""clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
1718+
1719+
# This implementation does not intend to handle when self is an empty tensor
16921720
max_rank = len(max_.shape)
16931721
if max_rank == 0:
16941722
max_ = op.CastLike(max_, self)
@@ -1699,11 +1727,20 @@ def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
16991727
return result
17001728

17011729

1702-
@torch_op(("aten::clamp_min", "aten::clamp_min.Tensor"), trace_only=True)
1703-
def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
1704-
"""clamp_min(Tensor self, Tensor min) -> Tensor"""
1730+
@torch_op("aten::clamp_min", trace_only=True)
1731+
def aten_clamp_min(self: TReal, min_: float) -> TReal:
1732+
"""clamp_min(Tensor self, Scalar min) -> Tensor"""
1733+
1734+
# This implementation does not intend to handle when self is an empty tensor
1735+
min_ = op.CastLike(min_, self)
1736+
return op.Clip(self, min_, None)
1737+
1738+
1739+
@torch_op("aten::clamp_min.Tensor", trace_only=True)
1740+
def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal:
1741+
"""clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
17051742

1706-
# This implementation does not intent to handle when self is an empty tensor
1743+
# This implementation does not intend to handle when self is an empty tensor
17071744
min_rank = len(min_.shape)
17081745
if min_rank == 0:
17091746
min_ = op.CastLike(min_, self)

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -717,11 +717,11 @@ def _where_input_wrangler(
717717
dtypes=(torch.bool,),
718718
reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
719719
),
720-
TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip(
720+
TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
721721
reason="Size 0 inputs are not handled by design",
722722
matcher=lambda sample: sample.input.numel() == 0,
723723
),
724-
TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip(
724+
TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip(
725725
reason="Size 0 inputs are not handled by design",
726726
matcher=lambda sample: sample.input.numel() == 0,
727727
),
@@ -1553,7 +1553,7 @@ def _where_input_wrangler(
15531553
variant_name="partial_views",
15541554
reason="ONNX doesn't have partial view for tensor",
15551555
),
1556-
TorchLibOpInfo("clamp", core_ops.aten_clamp),
1556+
TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor),
15571557
TorchLibOpInfo(
15581558
"ops.aten.col2im",
15591559
nn_ops.aten_col2im,

0 commit comments

Comments (0)