diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index f20c96ec41..6218d7ae9f 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -1662,13 +1662,32 @@ def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]: return op.SplitToSequence(self, list_split, axis=dim) -@torch_op(("aten::clamp", "aten::clamp.Tensor"), trace_only=True) -def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None) -> TReal: - """clamp(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor""" - clamped = self +@torch_op("aten::clamp", trace_only=True) +def aten_clamp(self: TReal, min: Optional[float] = None, max: Optional[float] = None) -> TReal: + """clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor""" if min is None and max is None: - return clamped + return op.Identity(self) + + if min is not None: + min = op.CastLike(min, self) + + if max is not None: + max = op.CastLike(max, self) + + return op.Clip(self, min, max) + + +@torch_op("aten::clamp.Tensor", trace_only=True) +def aten_clamp_tensor( + self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = None +) -> TReal: + """clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor""" + + if min is None and max is None: + return op.Identity(self) + + clamped = self # If min is greater than max torch.clamp(..., min, max) # sets all elements in input to the value of max. @@ -1684,11 +1703,20 @@ def aten_clamp(self: TReal, min: Optional[TReal] = None, max: Optional[TReal] = return clamped -@torch_op(("aten::clamp_max", "aten::clamp_max.Tensor"), trace_only=True) -def aten_clamp_max(self: TReal, max_: TReal) -> TReal: - """clamp_max(Tensor self, Tensor max) -> Tensor""" +@torch_op("aten::clamp_max", trace_only=True) +def aten_clamp_max(self: TReal, max_: float) -> TReal: + """clamp_max(Tensor self, Scalar max) -> Tensor""" + + # This implementation does not intend to handle when self is an empty tensor + max_ = op.CastLike(max_, self) + return op.Clip(self, None, max_) - # This implementation does not intent to handle when self is an empty tensor + +@torch_op("aten::clamp_max.Tensor", trace_only=True) +def aten_clamp_max_tensor(self: TReal, max_: TReal) -> TReal: + """clamp_max.Tensor(Tensor self, Tensor max) -> Tensor""" + + # This implementation does not intend to handle when self is an empty tensor max_rank = len(max_.shape) if max_rank == 0: max_ = op.CastLike(max_, self) @@ -1699,11 +1727,20 @@ def aten_clamp_max(self: TReal, max_: TReal) -> TReal: return result -@torch_op(("aten::clamp_min", "aten::clamp_min.Tensor"), trace_only=True) -def aten_clamp_min(self: TReal, min_: TReal) -> TReal: - """clamp_min(Tensor self, Tensor min) -> Tensor""" +@torch_op("aten::clamp_min", trace_only=True) +def aten_clamp_min(self: TReal, min_: float) -> TReal: + """clamp_min(Tensor self, Scalar min) -> Tensor""" + + # This implementation does not intend to handle when self is an empty tensor + min_ = op.CastLike(min_, self) + return op.Clip(self, min_, None) + + +@torch_op("aten::clamp_min.Tensor", trace_only=True) +def aten_clamp_min_tensor(self: TReal, min_: TReal) -> TReal: + """clamp_min.Tensor(Tensor self, Tensor min) -> Tensor""" - # This implementation does not intent to handle when self is an empty tensor + # This implementation does not intend to handle when self is an empty tensor min_rank = len(min_.shape) if min_rank == 0: min_ = op.CastLike(min_, self) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 3244ebd219..c6b52be0c5 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -717,11 +717,11 @@ def _where_input_wrangler( dtypes=(torch.bool,), reason="fixme: ORT does not implement SplitToSequence for bool inputs: https://github.com/microsoft/onnxruntime/issues/16905", ), - TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max).skip( + TorchLibOpInfo("clamp_max", core_ops.aten_clamp_max_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, ), - TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min).skip( + TorchLibOpInfo("clamp_min", core_ops.aten_clamp_min_tensor).skip( reason="Size 0 inputs are not handled by design", matcher=lambda sample: sample.input.numel() == 0, ), @@ -1553,7 +1553,7 @@ def _where_input_wrangler( variant_name="partial_views", reason="ONNX doesn't have partial view for tensor", ), - TorchLibOpInfo("clamp", core_ops.aten_clamp), + TorchLibOpInfo("clamp", core_ops.aten_clamp_tensor), TorchLibOpInfo( "ops.aten.col2im", nn_ops.aten_col2im,