From fd1398e10ce49d8f78a248b5dadb805b342118fe Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 20 Feb 2025 08:51:10 -0800
Subject: [PATCH 1/2] Split clamp ops into scalar and tensor variants

---
 .../function_libs/torch_lib/ops/core.py      | 59 +++++++++++++++----
 .../function_libs/torch_lib/ops_test_data.py |  6 +-
 2 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index f20c96ec41..df65e3f5c4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1662,13 +1662,32 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
     return op.SplitToSequence(self, list_split, axis=dim)
 
 
-@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True)
-def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal:
-    """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
-    clamped = self
+@torch_op("aten::clamp", trace_only=True)
+def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal:
+    """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor"""
 
     if min is None and max is None:
-        return clamped
+        return op.Identity(self)
+
+    if min is not None:
+        min = op.CastLike(min, self)
+
+    if max is not None:
+        max = op.CastLike(max, self)
+
+    return op.Clip(self, min, max)
+
+
+@torch_op("aten::clamp.Tensor", trace_only=True)
+def aten_clamp_tensor(
+    self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None
+) -> TReal:
+    """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor"""
+
+    if min is None and max is None:
+        return op.Identity(self)
+
+    clamped = self
 
     # If min is greater than max torch.clamp(..., min, max)
     # sets all elements in input to the value of max.
@@ -1684,9 +1703,18 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] =
     return clamped
 
 
-@torch_op(("aten::clamp_max", "aten::clamp_max.Tensor"), trace_only=True)
-def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
-    """clamp_max(Tensor self, Tensor max) -> Tensor"""
+@torch_op("aten::clamp_max", trace_only=True)
+def aten_clamp_max(self: TReal, max_: float) -> TReal:
+    """clamp_max(Tensor self, Scalar max) -> Tensor"""
+
+    # This implementation does not intent to handle when self is an empty tensor
+    max_ = op.CastLike(max_, self)
+    return op.Clip(self, None, max_)
+
+
+@torch_op("aten::clamp_max.Tensor", trace_only=True)
+def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
+    """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
 
     # This implementation does not intent to handle when self is an empty tensor
     max_rank = len(max_.shape)
@@ -1699,9 +1727,18 @@ def aten_clamp_max(self: TReal, max_: TReal) -> TReal:
     return result
 
 
-@torch_op(("aten::clamp_min", "aten::clamp_min.Tensor"), trace_only=True)
-def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
-    """clamp_min(Tensor self, Tensor min) -> Tensor"""
+@torch_op("aten::clamp_min", trace_only=True)
+def aten_clamp_min(self: TReal, min_: float) -> TReal:
+    """clamp_min(Tensor self, Scalar min) -> Tensor"""
+
+    # This implementation does not intent to handle when self is an empty tensor
+    min_ = op.CastLike(min_, self)
+    return op.Clip(self, min_, None)
+
+
+@torch_op("aten::clamp_min.Tensor", trace_only=True)
+def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal:
+    """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
 
     # This implementation does not intent to handle when self is an empty tensor
     min_rank = len(min_.shape)
diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py
index 3244ebd219..c6b52be0c5 100644
--- a/tests/function_libs/torch_lib/ops_test_data.py
+++ b/tests/function_libs/torch_lib/ops_test_data.py
@@ -717,11 +717,11 @@ def _where_input_wrangler(
         dtypes=(torch.bool,),
         reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905",
     ),
-    TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip(
+    TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip(
         reason="Size 0 inputs are not handled by design",
         matcher=lambda sample: sample.input.numel() == 0,
     ),
-    TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip(
+    TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip(
         reason="Size 0 inputs are not handled by design",
         matcher=lambda sample: sample.input.numel() == 0,
     ),
@@ -1553,7 +1553,7 @@ def _where_input_wrangler(
         variant_name="partial_views",
         reason="ONNX doesn't have partial view for tensor",
     ),
-    TorchLibOpInfo("clamp", core_ops.aten_clamp),
+    TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor),
     TorchLibOpInfo(
         "ops.aten.col2im",
         nn_ops.aten_col2im,

From 17799219002240f22bbdb2cea63e9ced25df262c Mon Sep 17 00:00:00 2001
From: Justin Chu
Date: Thu, 20 Feb 2025 08:58:42 -0800
Subject: [PATCH 2/2] typo

---
 onnxscript/function_libs/torch_lib/ops/core.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index df65e3f5c4..6218d7ae9f 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -1707,7 +1707,7 @@ def aten_clamp_tensor(
 def aten_clamp_max(self: TReal, max_: float) -> TReal:
     """clamp_max(Tensor self, Scalar max) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
     max_ = op.CastLike(max_, self)
     return op.Clip(self, None, max_)
 
@@ -1716,7 +1716,7 @@ def aten_clamp_max(self: TReal, max_: float) -> TReal:
 def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
     """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
     max_rank = len(max_.shape)
     if max_rank == 0:
         max_ = op.CastLike(max_, self)
@@ -1731,7 +1731,7 @@ def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal:
 def aten_clamp_min(self: TReal, min_: float) -> TReal:
     """clamp_min(Tensor self, Scalar min) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
     min_ = op.CastLike(min_, self)
     return op.Clip(self, min_, None)
 
@@ -1740,7 +1740,7 @@ def aten_clamp_min(self: TReal, min_: float) -> TReal:
 def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal:
     """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor"""
 
-    # This implementation does not intent to handle when self is an empty tensor
+    # This implementation does not intend to handle when self is an empty tensor
     min_rank = len(min_.shape)
     if min_rank == 0:
         min_ = op.CastLike(min_, self)
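
Below is a minimal PyTorch sketch (illustrative values; not part of the
patches above) of the clamp semantics the split ops target. The scalar
overloads can lower to a single ONNX Clip after a CastLike, since Clip
takes scalar min/max inputs, while the .Tensor overloads keep the
rank-dependent path for non-scalar bounds.

    import torch

    x = torch.tensor([1.0, 5.0, 9.0])

    # In-range bounds clamp element-wise.
    print(torch.clamp(x, min=2.0, max=8.0))  # tensor([2., 5., 8.])

    # When min is greater than max, torch.clamp sets every element to the
    # value of max, as the comment in aten_clamp_tensor notes.
    print(torch.clamp(x, min=8.0, max=2.0))  # tensor([2., 2., 2.])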