Commit 61d4ab5
[torchlib] Fix pow.Tensor_Scalar type promotion (#2335)

Fix pow.Tensor_Scalar type promotion by accounting for the different combinations of input dtypes. This change ensures that both inputs to Pow always have the same type, for compatibility with downstream tools. Also:

- Added is_floating_point to DataType for convenience. The method naming follows https://docs.pytorch.org/docs/stable/generated/torch.is_floating_point.html
- Simplified the str of a Value when it holds a constant.

Fix #2213

1 parent 8c0046f commit 61d4ab5
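For context on the promotion rules being matched: PyTorch picks the result dtype of tensor ** scalar from both operands, and the exporter must reproduce that while keeping both Pow inputs in one dtype for downstream tools. A quick illustration in plain PyTorch (not part of the commit):

```python
import torch

# int tensor ** int scalar stays integer
print((torch.tensor(2) ** 2).dtype)  # torch.int64

# int tensor ** float scalar promotes to the default float dtype
print((torch.tensor(2) ** 0.5).dtype)  # torch.float32

# float16 tensor ** int scalar keeps the tensor's dtype
print((torch.tensor(0.5, dtype=torch.float16) ** 2).dtype)  # torch.float16
```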

File tree

4 files changed: +72 -20 lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 14 additions & 1 deletion
@@ -6755,13 +6755,26 @@ def aten_positive(self: TensorType) -> TensorType:
 @torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True)
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""
+    # TODO(justinchuby): Add type promotion
     return op.Pow(self, exponent)
 
 
 @torch_op("aten::pow.Tensor_Scalar", trace_only=True)
 def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal:
     """pow(Tensor self, Scalar exponent) -> Tensor"""
-    return op.Pow(self, exponent)
+    if self.dtype.is_floating_point():
+        # self is floating point: materialize the scalar exponent in self's dtype
+        return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+    # self is an integer tensor from here on
+    if isinstance(exponent, int):
+        # An int exponent is representable in self's dtype, so no cast is needed
+        return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))
+
+    # exponent is a float, so cast self to match the exponent type.
+    # More precisely, if self is float64 we should cast exponent to float64; but
+    # this is uncommon and should be fixed when we create a general type promotion
+    # mechanism for torchlib
+    return op.Pow(op.Cast(self, to=FLOAT.dtype), exponent)
 
 
 @torch_op("aten::pow.Scalar", trace_only=True)

onnxscript/ir/_core.py

Lines changed: 3 additions & 5 deletions
@@ -1266,7 +1266,7 @@ class Usage(NamedTuple):
     idx: int
 
 
-def _short_tensor_str_for_node(x: Value) -> str:
+def _short_tensor_str(x: Value) -> str:
     if x.const_value is None:
         return ""
     if x.const_value.size <= 10:
@@ -1451,7 +1451,7 @@ def __str__(self) -> str:
             + ", ".join(
                 [
                     (
-                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
+                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str(x)}"
                         if x is not None
                         else "None"
                     )
@@ -1898,9 +1898,7 @@ def __str__(self) -> str:
 
         # Quote the name because in reality the names can have invalid characters
         # that make them hard to read
-        return (
-            f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
-        )
+        return f"%{_quoted(value_name)}<{type_text},{shape_text}>{_short_tensor_str(self)}"
 
     def _constant_tensor_part(self) -> str:
         """Display string for the constant tensor attached to str of Value."""

onnxscript/ir/_enums.py

Lines changed: 14 additions & 0 deletions
@@ -142,6 +142,20 @@ def short_name(self) -> str:
             raise TypeError(f"Short name not available for ONNX data type: {self}")
         return _DATA_TYPE_TO_SHORT_NAME[self]
 
+    def is_floating_point(self) -> bool:
+        """Returns True if the data type is a floating point type."""
+        return self in {
+            DataType.FLOAT,
+            DataType.FLOAT16,
+            DataType.DOUBLE,
+            DataType.BFLOAT16,
+            DataType.FLOAT8E4M3FN,
+            DataType.FLOAT8E4M3FNUZ,
+            DataType.FLOAT8E5M2,
+            DataType.FLOAT8E5M2FNUZ,
+            DataType.FLOAT4E2M1,
+        }
+
     def __repr__(self) -> str:
         return self.name
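Example usage of the new method, assuming DataType is reachable through the public onnxscript.ir namespace (a quick sanity check, not from the commit):

```python
from onnxscript import ir

assert ir.DataType.FLOAT16.is_floating_point()
assert ir.DataType.BFLOAT16.is_floating_point()
assert ir.DataType.FLOAT8E4M3FN.is_floating_point()
assert not ir.DataType.INT64.is_floating_point()
# Mirrors torch.is_floating_point, but as a method on the dtype enum
# rather than a free function over tensors.
```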

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 41 additions & 14 deletions
@@ -5,13 +5,11 @@
 
 import unittest
 
-import onnxruntime
 import torch
+from torch.onnx._internal.exporter import _testing
 
-from tests.common import testutils
 
-
-class TorchLibe2eTest(testutils.TestBase):
+class TorchLibe2eTest(unittest.TestCase):
     def test_investigate_one_particular_model(self):
         """This test can be used to investigate a particular issue."""
         red, include, stype = "amin", False, "int32"
@@ -35,19 +33,48 @@ def forward(self, x, indices, updates):
             torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
             torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
         )
-        expected = model(*xs)
-        model_path = (
-            f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx"
-        )
-        torch.onnx.export(model, xs, model_path, dynamo=True)
-        feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))
-
-        sess_options = onnxruntime.SessionOptions()
-        sess = onnxruntime.InferenceSession(
-            model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
-        )
-        got = sess.run(None, feeds)[0]
-        torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
+        onnx_program = torch.onnx.export(model, xs, dynamo=True)
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_int_float(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**0.5
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_int_int(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**2
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_float16_int(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**2
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
+
+    def test_pow_tensor_scalar_float16_float(self):
+        class PowModel(torch.nn.Module):
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                return x**0.5
+
+        onnx_program = torch.onnx.export(
+            PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
+        )
+        _testing.assert_onnx_program(onnx_program)
 
 
 if __name__ == "__main__":
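For reference, assert_onnx_program executes the exported program and compares it against eager PyTorch, replacing the hand-rolled onnxruntime session from the old scatter test. A rough standalone equivalent of what each pow test verifies, modeled on the removed code (check_export is an illustrative name, not project API):

```python
import onnxruntime
import torch


def check_export(model: torch.nn.Module, x: torch.Tensor, path: str) -> None:
    """Export `model`, run it under onnxruntime, and compare with eager mode."""
    torch.onnx.export(model, (x,), path, dynamo=True, optimize=False)
    sess = onnxruntime.InferenceSession(path, providers=["CPUExecutionProvider"])
    (got,) = sess.run(None, {sess.get_inputs()[0].name: x.numpy()})
    expected = model(x)
    torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
```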
