Skip to content

[torchlib] Fix pow.Tensor_Scalar type promotion #2335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6755,13 +6755,26 @@
@torch_op(("aten::pow.Tensor_Tensor", "_operator::pow"), trace_only=True)
def aten_pow(self: TReal, exponent: TTensor) -> TReal:
"""pow(Tensor self, Tensor exponent) -> Tensor"""
# TODO(justinchuby): Add type promotion
return op.Pow(self, exponent)


@torch_op("aten::pow.Tensor_Scalar", trace_only=True)
def aten_pow_tensor_scalar(self: TReal, exponent: float) -> TReal:
"""pow(Tensor self, Scalar exponent) -> Tensor"""
return op.Pow(self, exponent)
if self.dtype.is_floating_point():
# Handle cases when e.g. (1) self is float16 or int
return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))

Check warning on line 6767 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6767

Added line #L6767 was not covered by tests
# For integer types, we need to cast self to the exponent type
if isinstance(exponent, int):
# The scalar exponent can be an int
return op.Pow(self, ir.tensor(exponent, dtype=self.dtype))

Check warning on line 6771 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6771

Added line #L6771 was not covered by tests

# exponent is float so we cast self to match the exponent type.
# More precisely if self is float64, we should cast exponent to float64; but
# this is uncommon and should be fixed when we create a general type promotion
# mechanism for torchlib
return op.Pow(op.Cast(self, to=FLOAT.dtype), exponent)

Check warning on line 6777 in onnxscript/function_libs/torch_lib/ops/core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/function_libs/torch_lib/ops/core.py#L6777

Added line #L6777 was not covered by tests


@torch_op("aten::pow.Scalar", trace_only=True)
Expand Down
8 changes: 3 additions & 5 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,7 @@ class Usage(NamedTuple):
idx: int


def _short_tensor_str_for_node(x: Value) -> str:
def _short_tensor_str(x: Value) -> str:
if x.const_value is None:
return ""
if x.const_value.size <= 10:
Expand Down Expand Up @@ -1451,7 +1451,7 @@ def __str__(self) -> str:
+ ", ".join(
[
(
f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str(x)}"
if x is not None
else "None"
)
Expand Down Expand Up @@ -1898,9 +1898,7 @@ def __str__(self) -> str:

# Quote the name because in reality the names can have invalid characters
# that make them hard to read
return (
f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
)
return f"%{_quoted(value_name)}<{type_text},{shape_text}>{_short_tensor_str(self)}"

def _constant_tensor_part(self) -> str:
"""Display string for the constant tensor attached to str of Value."""
Expand Down
14 changes: 14 additions & 0 deletions onnxscript/ir/_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@
raise TypeError(f"Short name not available for ONNX data type: {self}")
return _DATA_TYPE_TO_SHORT_NAME[self]

def is_floating_point(self) -> bool:
"""Returns True if the data type is a floating point type."""
return self in {

Check warning on line 147 in onnxscript/ir/_enums.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/_enums.py#L147

Added line #L147 was not covered by tests
DataType.FLOAT,
DataType.FLOAT16,
DataType.DOUBLE,
DataType.BFLOAT16,
DataType.FLOAT8E4M3FN,
DataType.FLOAT8E4M3FNUZ,
DataType.FLOAT8E5M2,
DataType.FLOAT8E5M2FNUZ,
DataType.FLOAT4E2M1,
}

def __repr__(self) -> str:
return self.name

Expand Down
55 changes: 41 additions & 14 deletions tests/function_libs/torch_lib/e2e_ops_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@

import unittest

import onnxruntime
import torch
from torch.onnx._internal.exporter import _testing

from tests.common import testutils


class TorchLibe2eTest(testutils.TestBase):
class TorchLibe2eTest(unittest.TestCase):
def test_investigate_one_particular_model(self):
"""This test can be used to investigate a particular issue."""
red, include, stype = "amin", False, "int32"
Expand All @@ -35,19 +33,48 @@ def forward(self, x, indices, updates):
torch.tensor([[0, 0, 0], [1, 1, 1]], dtype=torch.int64),
torch.tensor([[-1, -1, -1], [-1, -1, -1]], dtype=dtype),
)
expected = model(*xs)
model_path = (
f"test_aten_scatter_{red}_{'include' if include else 'exclude'}_{stype}.onnx"
onnx_program = torch.onnx.export(model, xs, dynamo=True)
_testing.assert_onnx_program(onnx_program)

def test_pow_tensor_scalar_int_float(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**0.5

onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)

def test_pow_tensor_scalar_int_int(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**2

onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(2),), dynamo=True, optimize=False
)
_testing.assert_onnx_program(onnx_program)

def test_pow_tensor_scalar_float16_int(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**2

onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
)
torch.onnx.export(model, xs, model_path, dynamo=True)
feeds = dict(zip(["x", "indices", "updates"], [x.numpy() for x in xs]))
_testing.assert_onnx_program(onnx_program)

def test_pow_tensor_scalar_float16_float(self):
class PowModel(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x**0.5

sess_options = onnxruntime.SessionOptions()
sess = onnxruntime.InferenceSession(
model_path, sess_options=sess_options, providers=["CPUExecutionProvider"]
onnx_program = torch.onnx.export(
PowModel(), (torch.tensor(0.5, dtype=torch.float16),), dynamo=True, optimize=False
)
got = sess.run(None, feeds)[0]
torch.testing.assert_close(expected, torch.from_numpy(got), atol=1e-5, rtol=1e-5)
_testing.assert_onnx_program(onnx_program)


if __name__ == "__main__":
Expand Down
Loading