
Commit 95a24dd

Merge branch 'main' into xiaowu/fixBug(embedding_bag)

2 parents: 2666f56 + 10f9a1f

33 files changed: +1144 -279 lines

.lintrunner.toml

Lines changed: 2 additions & 2 deletions
@@ -70,7 +70,7 @@ init_command = [
 ]

 [[linter]]
-code = 'BLACK-ISORT'
+code = 'RUFF-FORMAT'
 include_patterns = [
     '**/*.py',
 ]
@@ -82,7 +82,7 @@ command = [
     '-m',
     'lintrunner_adapters',
     'run',
-    'black_isort_linter',
+    'ruff_format_linter',
     '--',
     '@{{PATHSFILE}}'
 ]

docs/examples/04_plot_eager_mode_evaluation.py

Lines changed: 1 addition & 3 deletions
@@ -18,9 +18,7 @@


 @script()
-def linear(
-    A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"]
-) -> FLOAT["N", "M"]:  # noqa: F821
+def linear(A: FLOAT["N", "K"], W: FLOAT["K", "M"], Bias: FLOAT["M"]) -> FLOAT["N", "M"]:  # noqa: F821
     T1 = op.MatMul(A, W)
     T2 = op.Add(T1, Bias)
     Y = op.Relu(T2)

noxfile.py

Lines changed: 2 additions & 2 deletions
@@ -27,8 +27,8 @@
     "pyyaml",
 )
 ONNX = "onnx==1.14.1"
-ONNX_RUNTIME = "onnxruntime==1.16.0"
-PYTORCH = "torch==2.0.1"
+ONNX_RUNTIME = "onnxruntime==1.16.1"
+PYTORCH = "torch==2.1.0"
 ONNX_RUNTIME_NIGHTLY_DEPENDENCIES = (
     "flatbuffers",
     "coloredlogs",

onnxscript/_internal/param_manipulation.py

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ def separate_input_attributes_from_arguments(
            else:
                onnx_attributes[param.name] = kwargs[param.name]
        elif (
-            param.is_attribute
-            and param.default is not values._EmptyDefault  # pylint: disable=protected-access
+            param.is_attribute and param.default is not values._EmptyDefault  # pylint: disable=protected-access
        ):
            # User did not provide the attribute
            if fill_defaults:

onnxscript/converter.py

Lines changed: 1 addition & 2 deletions
@@ -21,9 +21,8 @@
 import onnx

 import onnxscript
-from onnxscript import irbuilder, onnx_types, sourceinfo
+from onnxscript import irbuilder, onnx_types, sourceinfo, values
 from onnxscript import type_annotation as ta
-from onnxscript import values
 from onnxscript._internal import analysis, ast_utils, autocast, param_manipulation

 PY_VERSION_GE_39 = ast_utils.PY_VERSION_GE_39

onnxscript/converter_test.py

Lines changed: 5 additions & 1 deletion
@@ -17,6 +17,7 @@
 import numpy as np
 import onnx
 import onnxruntime as ort
+import pytest
 from onnxruntime.capi.onnxruntime_pybind11_state import (
     Fail,
     InvalidArgument,
@@ -270,7 +271,10 @@ def test_renaming(self):

         self.validate_save(renaming, shape_inference=False)

-    @unittest.skip(reason="TypeError: val must be numeric not <class 'NoneType'>")
+    @pytest.mark.xfail(
+        strict=True,
+        reason="default_opset must be specified in script for functions that do not contain any use of an ONNX op",
+    )
     def test_opt_output(self):
         from onnxscript.tests.models import opt_output

onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py

Lines changed: 7 additions & 6 deletions
@@ -30,7 +30,6 @@ class TestDeduceTypeConstraints(unittest.TestCase):
         "_aten_embedding_bag_onnx",
         "_aten_embedding_bag_1d_padding_idx_onnx",
     )
-    _SKIP_FUNCTIONS_WITH_NESTED_FUNCTION = ()

     @parameterized.parameterized.expand(
         ((op,) for op in torch_lib_onnx_functions_from_registry()),
@@ -41,11 +40,13 @@ def test_deduce_type_constraints_does_not_crash_for_onnx_function(
     ):
         if onnx_function.name in self._SKIP_FUNCTIONS_WITH_LOOP_OR_SCAN:
             self.skipTest("Unimplemented: function contains loop or scan node.")
-        if onnx_function.name in self._SKIP_FUNCTIONS_WITH_NESTED_FUNCTION:
-            self.skipTest("Unimplemented: function contains nested function.")
-        signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
-            onnx_function
-        )
+        try:
+            signature_type_constraint = deduce_type_constraints.deduce_type_constraints(
+                onnx_function
+            )
+        except NotImplementedError as e:
+            if "Nested function" in str(e):
+                self.skipTest("Unimplemented: function contains nested function.")
         logger.info(
             "Original signature: %s%s",
             onnx_function.name,

onnxscript/function_libs/tools/torch_lib/generate_prims_signatures.py

Lines changed: 0 additions & 12 deletions
@@ -12,11 +12,8 @@
 import os
 import re
 import textwrap
-from pathlib import Path
 from typing import Any, Dict, List, Sequence

-import black
-import isort
 import torch
 import torchgen.gen
 import torchgen.model
@@ -319,15 +316,6 @@ def main(args: argparse.Namespace) -> None:
         )
         py_module.accept(cg.PythonWriter(f))

-    # Format the generated files so that they pass linting.
-    # line_length=95 is to match the lintrunner rules.
-    isort.file(output_path)
-    black.format_file_in_place(
-        Path(output_path),
-        fast=True,
-        mode=black.Mode(line_length=95),
-        write_back=black.WriteBack.YES,
-    )
     print("Done.")

onnxscript/function_libs/torch_lib/_constants.py

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+"""Shared constants for the library."""
+
+DOMAIN = "pkg.onnxscript.torch_lib"

onnxscript/function_libs/torch_lib/_flags.py

Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+"""Experimental flags.
+
+NOTE: These flags are experimental only. Any flag here can be removed at any
+time without notice.
+"""
+
+import logging
+import os
+
+logger = logging.getLogger(__name__)
+
+
+def _load_boolean_flag(
+    name: str,
+    *,
+    this_will: str,
+    deprecated: bool = False,
+) -> bool:
+    """Load a boolean flag from environment variable.
+
+    Args:
+        name: The name of the environment variable.
+        this_will: A string that describes what this flag will do.
+        deprecated: Whether this flag is deprecated.
+    """
+    state = os.getenv(name) == "1"
+    if state:
+        if deprecated:
+            logger.error(
+                "Experimental flag %s is deprecated. Please remove it from your environment.",
+                name,
+            )
+        else:
+            logger.warning("Experimental flag %s is enabled. This will %s.", name, this_will)
+    return state
+
+
+EXPERIMENTAL_INITIALIZERS_AS_INPUTS: bool = _load_boolean_flag(
+    "TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS",
+    this_will="make initializers as inputs to the model graph",
+)
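
Because _load_boolean_flag reads os.getenv(name) once at module import, the variable must be set before torch_lib is first imported. A minimal sketch of enabling the flag (module path as in the diff above):

import os

# Set before importing, since the flag is evaluated once at import time.
os.environ["TORCHLIB_EXPERIMENTAL_INITIALIZERS_AS_INPUTS"] = "1"

from onnxscript.function_libs.torch_lib import _flags  # noqa: E402

assert _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS is True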

onnxscript/function_libs/torch_lib/graph_building.py

Lines changed: 25 additions & 4 deletions
@@ -21,6 +21,8 @@
 from onnxscript import evaluator
 from onnxscript import tensor as onnxscript_tensor
 from onnxscript._internal import param_manipulation, runtime_typing
+from onnxscript.function_libs.torch_lib import _flags
+from onnxscript.function_libs.torch_lib.ops import common as common_ops

 __all__ = [
     "TorchScriptTensor",
@@ -198,7 +200,7 @@ def symbolic_value(self) -> torch.Value:
 def _unwrap_tensor_to_torch_value(
     value: Union[
         ValidArgumentType, Mapping[str, ValidArgumentType], Sequence[ValidArgumentType]
-    ]
+    ],
 ) -> Union[
     ValidTorchValueType,
     Dict[str, ValidTorchValueType],
@@ -363,6 +365,16 @@ def _tensor_rawdata_size(tensor: torch.Tensor) -> int:
     return tensor.numel() * tensor.element_size()


+def _shared_functions() -> list[onnx.FunctionProto]:
+    """Hack to always include the share ops."""
+
+    # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+    return [
+        common_ops.Rank.to_function_proto(),
+        common_ops.IsScalar.to_function_proto(),
+    ]
+
+
 class TorchScriptGraph:
     def __init__(
         self,
@@ -717,7 +729,6 @@ def to_function_proto(self, opset_version: int, function_name: str) -> onnx.Func
             opset_imports=onnx_model.opset_import,
             doc_string=onnx_model.doc_string,
         )
-        # TODO: onnx.checker.check_function(onnx_function)?
         return onnx_function

     @runtime_typing.checked
@@ -740,13 +751,15 @@ def to_model_proto(
         large_model = initializers_size > _LARGE_MODEL_SIZE_THRESHOLD

         export_kwargs: dict[str, Any] = dict(
-            initializers=self.initializers if include_initializers else {},
+            initializers=self.initializers
+            if include_initializers and not _flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS
+            else {},
             onnx_opset_version=opset_version,
             dynamic_axes={},
             defer_weight_export=False,
             operator_export_type=torch.onnx.OperatorExportTypes.ONNX,
             strip_doc_string=False,
-            keep_initializers_as_inputs=False,
+            keep_initializers_as_inputs=_flags.EXPERIMENTAL_INITIALIZERS_AS_INPUTS,
             custom_opsets={},
             add_node_names=True,
             node_attr_to_name={},
@@ -786,6 +799,7 @@ def to_model_proto(
         onnx_model = onnx.load_from_string(proto)

         onnx_model.functions.extend(function_proto_dict.values())
+        onnx_model.functions.extend(_shared_functions())

         # `_export_onnx` only exports opset_imports that is visible to it. It does not
         # export opset_imports for nested functions, since it does not have access to
@@ -800,6 +814,13 @@ def to_model_proto(
                 for domain, version in unique_custom_domains.items()
             ]
         )
+        # Include the library shared opset domain
+        # TODO: Remove after https://github.com/microsoft/onnxscript/issues/834 is fixed
+        onnx_model.opset_import.append(
+            onnx.helper.make_opsetid(
+                common_ops.common_opset.domain, common_ops.common_opset.version
+            )
+        )

         try:
             if not cache_model_to_disk:
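
For reference, the appended entry is a plain OperatorSetIdProto; a small sketch of what it resolves to, assuming the shared domain is _constants.DOMAIN plus the ".common" suffix from ops/common.py:

import onnx

# Equivalent of the opset_import entry appended above.
opset_id = onnx.helper.make_opsetid("pkg.onnxscript.torch_lib.common", 1)
print(opset_id.domain, opset_id.version)  # pkg.onnxscript.torch_lib.common 1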

onnxscript/function_libs/torch_lib/graph_building_test.py

Lines changed: 0 additions & 2 deletions
@@ -11,11 +11,9 @@
 import onnxscript.testing
 from onnxscript import FLOAT, evaluator
 from onnxscript import opset18 as op
-from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_lib import graph_building, ops


-@unittest.skipIf(version_utils.torch_older_than("2.0"), "torchscript in 1.13 not supported")
 class TestTorchScriptTracingEvaluator(unittest.TestCase):
     def setUp(self):
         self.opset_version = 18

onnxscript/function_libs/torch_lib/ops/common.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+"""Common operators shared in the torchlib library."""
+
+import onnxscript
+import onnxscript.values
+from onnxscript import BOOL, INT64
+from onnxscript import opset18 as op
+from onnxscript.function_libs.torch_lib import _constants, tensor_typing
+from onnxscript.function_libs.torch_lib.tensor_typing import RealType
+from onnxscript.onnx_types import COMPLEX64, COMPLEX128, DOUBLE, FLOAT
+
+COMPLEX64_TYPE = COMPLEX64.dtype
+COMPLEX128_TYPE = COMPLEX128.dtype
+
+DOMAIN = f"{_constants.DOMAIN}.common"
+
+common_opset = onnxscript.values.Opset(domain=DOMAIN, version=1)
+
+
+@onnxscript.script(common_opset)
+def Rank(input: tensor_typing.TTensor) -> INT64:
+    """Take the rank of the input tensor."""
+
+    return op.Size(op.Shape(input))
+
+
+@onnxscript.script(common_opset)
+def IsScalar(input: tensor_typing.TTensor) -> BOOL:
+    """Return whether the input has rank 0, or is a scalar."""
+
+    return op.Equal(op.Size(op.Shape(input)), op.Constant(value_int=0))
+
+
+def cast_to(a: RealType, dtype: int) -> RealType:
+    """Cast input to dtype while handling complex types."""
+
+    # Traced function because different if branches return different dtypes
+    # which is not supported in an ONNX function
+    if dtype == COMPLEX128_TYPE:
+        # Cast to the real representation of the complex type
+        casted = op.Cast(a, to=DOUBLE.dtype)
+        # Create a complex number
+        real_part = op.Unsqueeze(casted, axes=[-1])
+        imag_part = op.Expand(op.Cast(0.0, to=DOUBLE.dtype), op.Shape(real_part))
+        result = op.Concat(real_part, imag_part, axis=-1)
+    elif dtype == COMPLEX64_TYPE:
+        # Cast to the real representation of the complex type
+        casted = op.Cast(a, to=FLOAT.dtype)
+        # Create a complex number
+        real_part = op.Unsqueeze(casted, axes=[-1])
+        imag_part = op.Expand(0.0, op.Shape(real_part))
+        result = op.Concat(real_part, imag_part, axis=-1)
+    else:
+        # Cast to real numbers
+        result = op.Cast(a, to=dtype)
+
+    return result
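
Since functions decorated with @onnxscript.script are also callable eagerly on numpy arrays, the shared ops can be sanity-checked directly. A minimal sketch (expected values shown in comments; exact output formatting depends on the eager evaluator):

import numpy as np

from onnxscript.function_libs.torch_lib.ops.common import IsScalar, Rank

x = np.ones((2, 3), dtype=np.float32)
print(Rank(x))                                    # expected: 2
print(IsScalar(x))                                # expected: False
print(IsScalar(np.array(1.0, dtype=np.float32)))  # expected: True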
