diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
index 758d87b904..1895b67fa4 100644
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -6755,13 +6755,26 @@ def aten_positive(self: TensorType) -> TensorType:
 @torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True)
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""
+    # TODO(justinchuby): Add type promotion
     return op.Pow(self, exponent)
 
 
 @torch_op("aten::pow.Tensor_Scalar", trace_only=True)
 def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal:
     """pow(Tensor self, Scalar exponent) -> Tensor"""
-    return op.Pow(self, exponent)
+    if self.dtype.is_floating_point():
+        # self is floating point (e.g. float16): cast the scalar exponent to self's dtype
+        return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+    # self is an integer type: match the exponent's type
+    if isinstance(exponent, int):
+        # An int exponent keeps self's integer dtype
+        return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+
+    # The exponent is a float, so we cast self to match the exponent type.
+    # More precisely, if self is float64 we should cast the exponent to float64;
+    # but this is uncommon and should be fixed when we create a general type
+    # promotion mechanism for torchlib
+    return op.Pow(op.Cast(self, to=FLOAT.dtype), exponent)
 
 
 @torch_op("aten::pow.Scalar", trace_only=True)
diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py
index 5530c4dc76..4fac12f74f 100644
--- a/onnxscript/ir/_core.py
+++ b/onnxscript/ir/_core.py
@@ -1266,7 +1266,7 @@ class Usage(NamedTuple):
     idx: int
 
 
-def _short_tensor_str_for_node(x: Value) -> str:
+def _short_tensor_str(x: Value) -> str:
     if x.const_value is None:
         return ""
     if x.const_value.size <= 10:
@@ -1451,7 +1451,7 @@ def __str__(self) -> str:
             + ", ".join(
                 [
                     (
-                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
+                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str(x)}"
                         if x is not None
                         else "None"
                     )
@@ -1898,9 +1898,7 @@ def __str__(self) -> str:
         # Quote the name because in reality the names can have invalid characters
         # that make them hard to read
-        return (
-            f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
-        )
+        return f"%{_quoted(value_name)}<{type_text},{shape_text}>{_short_tensor_str(self)}"
 
     def _constant_tensor_part(self) -> str:
         """Display string for the constant tensor attached to str of Value."""
diff --git a/onnxscript/ir/_enums.py b/onnxscript/ir/_enums.py
index 9ecce9fed3..bcaffe66cc 100644
--- a/onnxscript/ir/_enums.py
+++ b/onnxscript/ir/_enums.py
@@ -142,6 +142,20 @@ def short_name(self) -> str:
             raise TypeError(f"Short name not available for ONNX data type: {self}")
         return _DATA_TYPE_TO_SHORT_NAME[self]
 
+    def is_floating_point(self) -> bool:
+        """Returns True if the data type is a floating point type."""
+        return self in {
+            DataType.FLOAT,
+            DataType.FLOAT16,
+            DataType.DOUBLE,
+            DataType.BFLOAT16,
+            DataType.FLOAT8E4M3FN,
+            DataType.FLOAT8E4M3FNUZ,
+            DataType.FLOAT8E5M2,
+            DataType.FLOAT8E5M2FNUZ,
+            DataType.FLOAT4E2M1,
+        }
+
     def __repr__(self) -> str:
         return self.name
diff --git a/tests/function_libs/torch_lib/e2e_ops_tests.py b/tests/function_libs/torch_lib/e2e_ops_tests.py
index e933ab8d8b..7c2978f6de 100644
--- a/tests/function_libs/torch_lib/e2e_ops_tests.py
+++ b/tests/function_libs/torch_lib/e2e_ops_tests.py
@@ -5,13 +5,11 @@
 import unittest
 
-import onnxruntime
 import torch
+from torch.onnx._internal.exporter import _testing
 
-from tests.common import testutils
-
 
 
-class TorchLibe2eTest(testutils.TestBase):
+class TorchLibe2eTest(unittest.TestCase):
     def test_investigate_one_particular_model(self):
         """This test can be used to investigate a particular issue."""
         red, include, stype = "amin", False, "int32"
@@ -35,19 +33,48 @@ def forward(self, x, indices, updates):
             torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
             torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
         )
-        expected = model(*xs)
-        model_path = (
-            f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx"
+        onnx_program = torch.onnx.export(model, xs, dynamo=True)
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_int_float(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**0.5
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_int_int(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**2
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_float16_int(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**2
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
         )
-        torch.onnx.export(model, xs, model_path, dynamo=True)
-        feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_float16_float(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**0.5
 
-        sess_options = onnxruntime.SessionOptions()
-        sess = onnxruntime.InferenceSession(
-            model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
         )
-        got = sess.run(None, feeds)[0]
-        torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
+        _testing.assert_onnx_program(onnx_program)
 
 
 if __name__ == "__main__":
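
Note (not part of the patch): the new `aten_pow_tensor_scalar` encodes a small dtype-promotion table. Below is a minimal sketch of that table in plain Python against the public `onnxscript.ir` API; `pow_result_dtype` is a hypothetical name used only for illustration:

```python
from onnxscript import ir


def pow_result_dtype(self_dtype: ir.DataType, exponent: float) -> ir.DataType:
    """Result dtype that aten_pow_tensor_scalar produces for self ** exponent."""
    if self_dtype.is_floating_point():
        # Floating-point self: the exponent is cast to self's dtype
        return self_dtype
    if isinstance(exponent, int):
        # Integer self with an int exponent stays in self's dtype
        return self_dtype
    # Integer self with a float exponent: self is cast to float32
    return ir.DataType.FLOAT


assert pow_result_dtype(ir.DataType.FLOAT16, 2) == ir.DataType.FLOAT16
assert pow_result_dtype(ir.DataType.INT32, 2) == ir.DataType.INT32
assert pow_result_dtype(ir.DataType.INT32, 0.5) == ir.DataType.FLOAT
```

The four new e2e tests cover exactly these rows: int ** float, int ** int, float16 ** int, and float16 ** float.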
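Similarly, for the `_core.py` change: `Value.__str__` now reuses the renamed `_short_tensor_str` helper instead of the separate `_constant_tensor_part` formatting, so nodes and standalone values render small constants the same way. A sketch, assuming `ir.Value` accepts these keyword arguments (the exact printed form depends on the helper):

```python
from onnxscript import ir

# Assumption: name/shape/type/const_value are keyword arguments of ir.Value.
v = ir.Value(
    name="w",
    shape=ir.Shape([2]),
    type=ir.TensorType(ir.DataType.FLOAT),
    const_value=ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT),
)
# const_value is small (size <= 10, per _short_tensor_str), so its elements
# are shown inline, e.g. something like: %"w"<FLOAT,[2]>{1.0, 2.0}
print(v)
```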