[torchlib] Implement clamp* scalar overloads #2066

Merged · 2 commits · Feb 20, 2025
63 changes: 50 additions & 13 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -1662,13 +1662,32 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
     return op.SplitToSequence(self, list_split, axis=dim)
 
 
-@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
-def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
-    """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
-    clamped = self
+@torch_op("aten::clamp", trace_only=True)
+def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal:
+    """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"""
 
     if min is None and max is None:
-        return clamped
+        return op.Identity(self)
+
+    if min is not None:
+        min = op.CastLike(min, self)
+
+    if max is not None:
+        max = op.CastLike(max, self)
+
+    return op.Clip(self, min, max)
+
+
+@torch_op("aten::clamp.Tensor", trace_only=True)
+def aten_clamp_tensor(
+    self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None
+) -> TReal:
+    """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
+
+    if min is None and max is None:
+        return op.Identity(self)
+
+    clamped = self
 
     # If min is greater than max torch.clamp(..., min, max)
     # sets all elements in input to the value of max.
@@ -1684,11 +1703,20 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] =
     return clamped
 
 
-@torch_op(("aten::clamp_max", "aten::clamp_max.Tensor"), trace_only=True)
-def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
-    """clamp_max(Tensor self, Tensor max) -> Tensor"""
+@torch_op("aten::clamp_max", trace_only=True)
+def aten_clamp_max(self: TReal, max_: float) -> TReal:
+    """clamp_max(Tensor self, Scalar max) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+@torch_op("aten::clamp_max.Tensor", trace_only=True)
+def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
+    """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
+
+    # This implementation does not intend to handle when self is an empty tensor
     max_rank = len(max_.shape)
     if max_rank == 0:
         max_ = op.CastLike(max_, self)
@@ -1699,11 +1727,20 @@ def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
     return result
 
 
-@torch_op(("aten::clamp_min", "aten::clamp_min.Tensor"), trace_only=True)
-def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
-    """clamp_min(Tensor self, Tensor min) -> Tensor"""
+@torch_op("aten::clamp_min", trace_only=True)
+def aten_clamp_min(self: TReal, min_: float) -> TReal:
+    """clamp_min(Tensor self, Scalar min) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)
+
+
+@torch_op("aten::clamp_min.Tensor", trace_only=True)
+def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal:
+    """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
+
+    # This implementation does not intend to handle when self is an empty tensor
     min_rank = len(min_.shape)
     if min_rank == 0:
         min_ = op.CastLike(min_, self)
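For reference, a minimal PyTorch-only sketch of the clamp semantics these split overloads target (sample values and printed results are illustrative, not taken from this PR): Python-float bounds go through the Scalar overloads, tensor bounds through the .Tensor overloads, and when min exceeds max every element is set to max, which is the behavior the comment in aten_clamp_tensor refers to.

import torch

x = torch.tensor([-2.0, 0.5, 3.0])

# Scalar bounds (Python floats) -> aten::clamp / clamp_min / clamp_max overloads.
print(torch.clamp(x, min=0.0, max=1.0))       # tensor([0.0000, 0.5000, 1.0000])

# Tensor bounds -> the *.Tensor overloads (and may broadcast against self).
print(torch.clamp(x, min=torch.tensor(0.0)))  # tensor([0.0000, 0.5000, 3.0000])

# min > max: every element becomes max.
print(torch.clamp(x, min=2.0, max=1.0))       # tensor([1., 1., 1.])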
6 changes: 3 additions & 3 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -717,11 +717,11 @@ def _where_input_wrangler(
         dtypes=(torch.bool,),
         reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
     ),
-    TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip(
+    TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
         reason="Size 0 inputs are not handled by design",
         matcher=lambda sample: sample.input.numel() == 0,
     ),
-    TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip(
+    TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip(
         reason="Size 0 inputs are not handled by design",
         matcher=lambda sample: sample.input.numel() == 0,
     ),
@@ -1553,7 +1553,7 @@ def _where_input_wrangler(
         variant_name="partial_views",
         reason="ONNX doesn't have partial view for tensor",
     ),
-    TorchLibOpInfo("clamp", core_ops.aten_clamp),
+    TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor),
     TorchLibOpInfo(
         "ops.aten.col2im",
         nn_ops.aten_col2im,
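A hypothetical end-to-end check of the new registrations (assumes a recent PyTorch build where torch.onnx.export(..., dynamo=True) lowers aten ops through torchlib; the module name and shapes below are illustrative, not part of this PR):

import torch


class ClampScalar(torch.nn.Module):
    def forward(self, x):
        # Python-float bounds should dispatch to the aten::clamp Scalar overload.
        return torch.clamp(x, min=0.0, max=1.0)


onnx_program = torch.onnx.export(ClampScalar(), (torch.randn(4),), dynamo=True)
# The exported graph is expected to contain a single Clip node for the clamp.
print(onnx_program.model)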