Skip to content

Commit 109bf7f

Browse files
authored
fix(atenlib): Cast SymInt to INT64 to fix Windows CI (#292)
PyTorch can supply int32 inputs as SymInt values. This change adds explicit casts to them to fix the Windows CI. Fixed additional mypy errors. #289 may be needed for Python 3.10.
1 parent 9c75044 commit 109bf7f

File tree

6 files changed

+36
-28
lines changed

6 files changed

+36
-28
lines changed

.github/workflows/main.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ jobs:
7777
pip list | grep torch
7878
7979
- name: pytest
80-
run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n auto
80+
run: pytest -v onnxscript --cov=onnxscript --cov-report=xml -n=auto
8181

8282
- name: Install package
8383
run: pip install .
8484

8585
- name: Test examples
8686
if: ${{ matrix.test_examples }}
87-
run: pytest -v docs/test
87+
run: pytest -v docs/test -n=auto
8888

8989
- name: Build package
9090
run: python -m build

onnxscript/backend/onnx_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def _read_proto_from_file(full):
7777
loaded = to_list(seq) # type: ignore[assignment]
7878
except Exception: # pylint: disable=W0703
7979
try:
80-
loaded = onnx.load_model_from_string(serialized)
81-
except Exception: # pragma: no cover
80+
loaded = onnx.load_model_from_string(serialized) # type: ignore[assignment]
81+
except Exception:
8282
raise RuntimeError(
8383
f"Unable to read {full!r}, error is {e}, "
8484
f"content is {serialized[:100]!r}."

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from onnxscript import BOOL, DOUBLE, FLOAT, INT16, INT32, INT64
1717
from onnxscript.function_libs.torch_aten.registration import torch_op
1818
from onnxscript.function_libs.torch_aten.typing import (
19+
IntType,
1920
TFloat,
2021
TFloatOrBFloat16,
2122
TInt,
@@ -1642,10 +1643,10 @@ def aten_exp2(self: TFloat) -> TFloat:
16421643

16431644

16441645
@torch_op("aten::expand")
1645-
def aten_expand(self: TTensor, size: INT64) -> TTensor:
1646+
def aten_expand(self: TTensor, size: TInt) -> TTensor:
16461647
# expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)
16471648

1648-
size = op.Cast(size, to=INT64.dtype) # to INT64
1649+
size = op.Cast(size, to=INT64.dtype)
16491650
return op.Expand(self, size)
16501651

16511652

@@ -3518,10 +3519,11 @@ def aten_new_empty_strided(self: TensorType, size: INT64, stride: INT64) -> Tens
35183519

35193520
@torch_op("aten::new_full")
35203521
def aten_new_full(
3521-
self, size: INT64, fill_value, dtype: int = FLOAT.dtype
3522+
self, size: IntType, fill_value, dtype: int = FLOAT.dtype
35223523
): # pylint: disable=unused-argument
35233524
# new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
35243525

3526+
size = op.Cast(size, to=INT64.dtype)
35253527
fill_value = op.Cast(fill_value, to=dtype)
35263528

35273529
return op.Expand(fill_value, size)
@@ -3585,12 +3587,12 @@ def aten_nuclear_norm(self: TensorType, keepdim: bool = False) -> TensorType:
35853587

35863588

35873589
@torch_op("aten::ones")
3588-
def aten_ones(size: INT64, dtype: int = -1):
3590+
def aten_ones(size: IntType, dtype: int = FLOAT.dtype):
35893591
# ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
35903592

3593+
size = op.Cast(size, to=INT64.dtype)
35913594
one = op.Constant(value_float=1)
3592-
if dtype != -1:
3593-
one = op.Cast(one, to=dtype)
3595+
one = op.Cast(one, to=dtype)
35943596
return op.Expand(one, size)
35953597

35963598

@@ -4088,13 +4090,14 @@ def aten_renorm(self: TensorType, p: float, dim: int, maxnorm: float) -> TensorT
40884090

40894091

40904092
@torch_op("aten::repeat")
4091-
def aten_repeat(self: TTensor, repeats: INT64) -> TTensor:
4093+
def aten_repeat(self: TTensor, repeats: TInt) -> TTensor:
40924094
# repeat(Tensor self, SymInt[] repeats) -> Tensor
40934095

40944096
if op.Size(repeats) == 0:
40954097
result = self
40964098
else:
40974099
# TODO(justinchuby): Make ones_like a function when onnxscript supports it
4100+
repeats = op.Cast(repeats, to=INT64.dtype)
40984101
# shape = ones_like(repeats) := {
40994102
one = op.Constant(value_int=1)
41004103
repeats_shape = op.Shape(repeats)
@@ -4114,10 +4117,11 @@ def aten_repeat_interleave(
41144117

41154118

41164119
@torch_op("aten::reshape")
4117-
def aten_reshape(self: TTensor, shape: INT64) -> TTensor:
4120+
def aten_reshape(self: TTensor, shape: IntType) -> TTensor:
41184121
# reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)
41194122

4120-
shape = op.Cast(shape, to=INT64.dtype) # Reshape only support INT64 as 'shape'
4123+
# Reshape only support INT64 as 'shape'
4124+
shape = op.Cast(shape, to=INT64.dtype)
41214125
return op.Reshape(self, shape)
41224126

41234127

@@ -4975,7 +4979,7 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
49754979

49764980

49774981
@torch_op("aten::view")
4978-
def aten_view(self: TTensor, size: INT64) -> TTensor:
4982+
def aten_view(self: TTensor, size: IntType) -> TTensor:
49794983
# view(Tensor(a) self, SymInt[] size) -> Tensor(a)
49804984

49814985
size = op.Cast(size, to=INT64.dtype) # Reshape only support INT64 as second input
@@ -5044,12 +5048,12 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
50445048

50455049

50465050
@torch_op("aten::zeros")
5047-
def aten_zeros(size: INT64, dtype: int = -1):
5051+
def aten_zeros(size: IntType, dtype: int = FLOAT.dtype):
50485052
# zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
50495053

5054+
size = op.Cast(size, to=INT64.dtype)
50505055
zero = op.Constant(value_float=0)
5051-
if dtype != -1:
5052-
zero = op.Cast(zero, to=dtype)
5056+
zero = op.Cast(zero, to=dtype)
50535057

50545058
return op.Expand(zero, size)
50555059

onnxscript/function_libs/torch_aten/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
UINT8,
4242
]
4343
_FloatType = Union[FLOAT16, FLOAT, DOUBLE]
44-
_IntType = Union[INT8, INT16, INT32, INT64]
44+
IntType = Union[INT8, INT16, INT32, INT64]
4545
RealType = Union[
4646
BFLOAT16,
4747
FLOAT16,
@@ -56,7 +56,7 @@
5656
TTensor = TypeVar("TTensor", bound=_TensorType)
5757
TFloat = TypeVar("TFloat", bound=_FloatType)
5858
TFloatOrBFloat16 = TypeVar("TFloatOrBFloat16", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16])
59-
TInt = TypeVar("TInt", bound=_IntType)
59+
TInt = TypeVar("TInt", bound=IntType)
6060
TReal = TypeVar("TReal", bound=RealType)
6161
TRealUnlessInt16OrInt8 = TypeVar(
6262
"TRealUnlessInt16OrInt8", bound=Union[FLOAT16, FLOAT, DOUBLE, BFLOAT16, INT32, INT64]

onnxscript/utils.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from __future__ import annotations
66

77
import numbers
8-
from typing import Any, Optional, Sequence
8+
from typing import Any, Iterable, Optional, Sequence
99

1010
import numpy as np
1111
import onnx
12+
import onnx.helper
13+
import onnx.mapping
1214
from onnx import FunctionProto, ModelProto, TensorProto, ValueInfoProto
13-
from onnx.helper import make_sequence_type_proto, make_tensor_type_proto
1415

1516
from onnxscript import tensor
1617

@@ -82,22 +83,24 @@ def value_to_type_proto(val):
8283
if isinstance(val, (np.ndarray, tensor.Tensor)):
8384
elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[val.dtype]
8485
shape = val.shape
85-
return make_tensor_type_proto(elem_type, shape)
86+
return onnx.helper.make_tensor_type_proto(elem_type, shape)
8687
if isinstance(val, int):
87-
return make_tensor_type_proto(TensorProto.INT32, [])
88+
return onnx.helper.make_tensor_type_proto(TensorProto.INT32, [])
8889
if isinstance(val, (float, np.float32)):
89-
return make_tensor_type_proto(TensorProto.FLOAT, [])
90+
return onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, [])
9091
if isinstance(val, list):
9192
if len(val) > 0:
92-
return make_sequence_type_proto(value_to_type_proto(val[0]))
93+
return onnx.helper.make_sequence_type_proto(value_to_type_proto(val[0]))
9394
# Edge-case. Cannot determine a suitable ONNX type for an empty list.
9495
# Should be using a typed-value instead.
9596
# Treated as a sequence of tensors of float-type.
96-
return make_sequence_type_proto(make_tensor_type_proto(TensorProto.FLOAT, None))
97+
return onnx.helper.make_sequence_type_proto(
98+
onnx.helper.make_tensor_type_proto(TensorProto.FLOAT, None)
99+
)
97100
if isinstance(val, numbers.Number):
98101
nparray = np.array(val)
99102
elem_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[nparray.dtype]
100-
return make_tensor_type_proto(elem_type, [])
103+
return onnx.helper.make_tensor_type_proto(elem_type, [])
101104
raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
102105

103106

@@ -144,7 +147,7 @@ def make_model_from_function_proto(
144147
**(attrs or {}),
145148
)
146149
graph = onnx.helper.make_graph([node], "node_graph", input_value_infos, output_value_infos)
147-
model_proto_opset = function_proto.opset_import
150+
model_proto_opset: Iterable[onnx.OperatorSetIdProto] = function_proto.opset_import
148151
if all(o.domain != function_proto.domain for o in model_proto_opset):
149152
model_proto_opset = [
150153
*model_proto_opset,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ onnx = ["py.typed"]
4545

4646
[tool.pytest.ini_options]
4747
filterwarnings = ["ignore::UserWarning", "ignore::DeprecationWarning"]
48+
addopts = "-ra --tb=short --color=yes"
4849

4950
[tool.mypy]
5051
follow_imports = "silent" # TODO: Remove when we fix all the mypy errors

0 commit comments

Comments (0)