Commit 85d886d

Update base for Update on "Implement the experimental evaluator for folding branches and castlikes | feat(torchlib)"
As part of the effort described in #1095, this PR:

- Implements the experimental evaluator for folding branches and castlikes, so that they are eagerly evaluated when possible.
- Updates the implementation of `addr` so that it is traceable.
- Conditionally enables previously xfailed tests. Set `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1`; tested in CI.

E.g. `clamp_min` now becomes

```
<
   ir_version: 8,
   opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9)
   <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8>
{
   _val_2 = Size (input_0)
   _val_3 = Shape <start: int = 0> (input_1)
   _val_4 = Size (_val_3)
   _val_5 = Constant <value: tensor = int64 {0}> ()
   _val_6 = Equal (_val_2, _val_5)
   _val_7 = Constant <value: tensor = int64 {0}> ()
   _val_8 = Equal (_val_4, _val_7)
   _val_9 = Max (input_0, input_1)
}
<
   domain: "pkg.onnxscript.torch_lib.common",
   opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
   domain: "pkg.onnxscript.torch_lib.common",
   opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```

[ghstack-poisoned]
2 parents 2ab2b19 + 744cabd commit 85d886d
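For intuition: "folding" here means the evaluator computes a branch predicate eagerly at trace time whenever its inputs are statically known, so only the taken side is emitted (the conditional around the scalar/rank checks in the graph above disappears, leaving a bare `Max`). A minimal sketch of the idea — the helper names below are hypothetical, not the torchlib API:

```python
# Hypothetical sketch of branch folding; not the actual torchlib evaluator.
def fold_if(cond, then_fn, else_fn, *args):
    if isinstance(cond, bool):
        # Predicate known at trace time: emit only the taken branch.
        return then_fn(*args) if cond else else_fn(*args)
    # Predicate is symbolic: fall back to a real ONNX If node
    # (build_onnx_if is a stand-in name for that fallback).
    return build_onnx_if(cond, then_fn, else_fn, *args)
```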

File tree

7 files changed: +182 −37 lines

onnxscript/backend/onnx_export.py

Lines changed: 12 additions & 19 deletions
```diff
@@ -258,9 +258,9 @@ def _rename_variable_s(self, name):
         return str(self._rename_variable(name))

     def _rename_domain(self, domain: str) -> str:
-        if domain == "":
-            return "opset"
-        return domain.replace(".", "_")
+        if domain in {"", "ai.onnx"}:
+            return "opset"  # TODO: Need checks to avoid name conflicts.
+        return _cleanup_variable_name(domain)  # type: ignore[return-value]

     def _make_opset_name(self, domain, version):
         return f"{self._rename_domain(domain)}{version}"
@@ -552,11 +552,13 @@ def add_line(line: str) -> None:
         add_line(f"    return {return_values}")
         return "\n".join(result)

-    def _translate_graph(self, model: onnx.ModelProto, function_name: str) -> str:
+    def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str]) -> str:
         graph = model.graph
         opsets = {}
         for imported in model.opset_import:
             opsets[imported.domain] = imported.version
+        if function_name is None:
+            function_name = _cleanup_variable_name(graph.name)

         result: list[str] = []

@@ -593,7 +595,9 @@ def _import_onnx_types(
         return "from onnxscript.onnx_types import " + ", ".join(sorted_types)
         return ""

-    def export(self, proto: onnx.ModelProto | onnx.FunctionProto, function_name: str) -> str:
+    def export(
+        self, proto: onnx.ModelProto | onnx.FunctionProto, function_name: Optional[str]
+    ) -> str:
         result: list[str] = []

         def add(line: str) -> None:
@@ -612,7 +616,6 @@ def add(line: str) -> None:
             translated_functions.append(self._translate_graph(proto, function_name))
         else:
             assert isinstance(proto, FunctionProto)
-            # TODO: use function_name?
             translated_functions = [self._translate_function(proto)]

         # TODO: unique_function_domain_version.add((f.domain, 1))
@@ -655,22 +658,15 @@ def visit_graph(graph: onnx.GraphProto) -> None:

 def export2python(
     model_onnx,
-    opset=None,
-    verbose=True,
-    name=None,
-    rename=False,
-    function_name="main",
-    use_operators=False,
+    function_name: Optional[str] = None,
+    rename: bool = False,
+    use_operators: bool = False,
     inline_const: bool = False,
 ):
     """Exports an ONNX model to the *python* syntax.

     Args:
         model_onnx: string or ONNX graph
-        opset: opset to export to (None to select the one from the
-            graph)
-        verbose: inserts prints
-        name: to overwrite onnx name
         rename: rename the names to get shorter names
         function_name: main function name
         use_operators: use Python operators.
@@ -694,9 +690,6 @@ def export2python(
         code = export2python(onx)
         print(code)
     """
-    del opset  # unused
-    del verbose  # unused
-    del name  # unused
     if isinstance(model_onnx, str):
         model_onnx = onnx.load(model_onnx)
```
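A short usage sketch of the updated signature (assuming this import path; with `function_name=None` the exporter now falls back to a sanitized `graph.name`):

```python
import onnx

from onnxscript.backend.onnx_export import export2python

model = onnx.load("model.onnx")
print(export2python(model))                        # function named after model.graph.name
print(export2python(model, function_name="main"))  # explicit name still works
```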

onnxscript/function_libs/torch_lib/ops/fft.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -95,11 +95,18 @@ def _fftn_onnx(
     # dimension at the beginning to represent the batch dimension.
     transformed = op.Unsqueeze(self, axes=[0])

-    for dim_ in dims:
-        if dim_ >= 0:
-            # Add 1 to account for the batch dimension when counting axes from the left
-            dim_ = dim_ + 1
-        transformed = op.DFT(transformed, axis=dim_, inverse=inverse, onesided=onesided)
+    # Add 1 to account for the batch dimension when counting axes from the left
+    new_dims = [dim_ + 1 if dim_ >= 0 else dim_ for dim_ in dims]
+
+    for dim in new_dims[:-1]:
+        transformed = op.DFT(transformed, axis=dim, inverse=inverse, onesided=False)
+
+    # Torch computes the one-sided FFT on the last dimension only.
+    if onesided:
+        transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=True)
+    else:
+        transformed = op.DFT(transformed, axis=new_dims[-1], inverse=inverse, onesided=False)

     # Remove the batch dimension
     transformed = op.Squeeze(transformed, axes=[0])
```
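The restructuring mirrors PyTorch's convention that a one-sided (real-input) FFT halves only the last transformed dimension; all other dimensions get full two-sided transforms. A small numpy illustration of that semantics:

```python
import numpy as np

x = np.random.rand(5, 8).astype(np.float32)
full = np.fft.fftn(x, axes=(0, 1))   # two-sided along both axes
half = np.fft.rfftn(x, axes=(0, 1))  # one-sided along the last axis only
assert half.shape == (5, 8 // 2 + 1)
# The one-sided result is a prefix of the full spectrum on the last axis.
np.testing.assert_allclose(half, full[:, : 8 // 2 + 1], rtol=1e-4, atol=1e-5)
```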

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -1404,10 +1404,18 @@ def aten_pad_sequence(
     raise NotImplementedError()


-def aten_reflection_pad1d(self: TensorType, padding: INT64) -> TensorType:
+@torch_op("aten::reflection_pad1d")
+def aten_reflection_pad1d(self: TFloat, padding: INT64) -> TFloat:
     """reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor"""

-    raise NotImplementedError()
+    # assert len(padding) == 2
+    # The padding argument arrives as [x, y]; ONNX Pad expects the format [0, x, 0, y]
+    start = op.Slice(padding, [0], [1], axes=[0])
+    end = op.Slice(padding, [1], [2], axes=[0])
+    padding_onnx = op.Concat(
+        op.Constant(value_ints=[0]), start, op.Constant(value_ints=[0]), end, axis=0
+    )
+    return op.Pad(self, padding_onnx, mode="reflect")


 def aten_reflection_pad1d_backward(
```
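For the layout conversion: torch passes the two pad amounts as `(left, right)` on the last dimension, while ONNX `Pad` wants all begin values followed by all end values, one per axis — hence `[0, x, 0, y]` for a 2-D input. A quick torch check of the reflect behavior being targeted:

```python
import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(2, 3)
y = F.pad(x, (1, 2), mode="reflect")  # torch layout: (left=1, right=2)
print(y.shape)  # torch.Size([2, 6])
# The equivalent ONNX Pad "pads" input for this 2-D tensor is [0, 1, 0, 2]:
# begins for (axis0, axis1), then ends for (axis0, axis1).
```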

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 72 additions & 10 deletions
```diff
@@ -190,21 +190,20 @@ def sample_inputs_convolution(op_info, device, dtype, requires_grad, **kwargs):
     )


-def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
-    del self  # Unused
+def _prepare_data_for_fft_ops(device, dtype, requires_grad=False):
     # Adapted from https://github.com/pytorch/pytorch/blob/01069ad4be449f376cf88a56d842b8eb50f6e9b6/torch/testing/_internal/opinfo/core.py#L2448C1-L2541C79
     is_fp16_or_chalf = dtype in (torch.complex32, torch.half)
     if not is_fp16_or_chalf:
-        nd_tensor = functools.partial(
+        oned_tensor = functools.partial(
             opinfo_core.make_tensor,
-            (S, S + 1, S + 2),
+            (31,),
             device=device,
             dtype=dtype,
             requires_grad=requires_grad,
         )
-        oned_tensor = functools.partial(
+        nd_tensor = functools.partial(
             opinfo_core.make_tensor,
-            (31,),
+            (S, S + 1, S + 2),
             device=device,
             dtype=dtype,
             requires_grad=requires_grad,
@@ -214,25 +213,32 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
         high = None
         shapes = ((2, 8, 9), (33,))

-        nd_tensor = functools.partial(
+        oned_tensor = functools.partial(
             opinfo_core.make_tensor,
-            shapes[0],
+            shapes[1],
             device=device,
             low=low,
             high=high,
             dtype=dtype,
             requires_grad=requires_grad,
         )
-        oned_tensor = functools.partial(
+        nd_tensor = functools.partial(
             opinfo_core.make_tensor,
-            shapes[1],
+            shapes[0],
             device=device,
             low=low,
             high=high,
             dtype=dtype,
             requires_grad=requires_grad,
         )

+    return oned_tensor, nd_tensor
+
+
+def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
+    del self  # Unused
+    oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)
+
     for normalization, forward in itertools.product((0, 1, 2), (True, False)):
         # 1-D
         yield opinfo_core.SampleInput(
@@ -252,6 +258,29 @@ def sample_inputs__fft_c2c(self, device, dtype, requires_grad=False, **_):
         )


+def sample_inputs__fft_r2c(self, device, dtype, requires_grad=False, **_):
+    del self  # Unused
+    oned_tensor, nd_tensor = _prepare_data_for_fft_ops(device, dtype, requires_grad)
+
+    for normalization, one_sided in itertools.product((0, 1, 2), (True, True)):
+        # 1-D
+        yield opinfo_core.SampleInput(
+            oned_tensor(), dim=(0,), normalization=normalization, onesided=one_sided
+        )
+        # N-D
+        for dim in [
+            (0,),
+            (1,),
+            (2,),
+            (1, 2),
+            (0, 1),
+            (0, 1, 2),
+        ]:
+            yield opinfo_core.SampleInput(
+                nd_tensor(), dim=dim, normalization=normalization, onesided=one_sided
+            )
+
+
 def sample_inputs_layer_norm(op_info, device, dtype, requires_grad, **kwargs):
     del op_info  # unused
     del kwargs
@@ -1336,6 +1365,25 @@ def sample_inputs__native_batch_norm_legit_no_stats(
     )


+def sample_inputs_reflection_pad1d(op_info, device, dtype, requires_grad, **kwargs):
+    del op_info
+    del kwargs
+
+    cases: tuple = (  # ignore
+        ((2, 3), (1, 2)),
+        ((4, 5), (0, 1)),
+        ((6, 7), (1, 1)),
+        ((8, 9), (1, 0)),
+    )
+
+    make_inp = opinfo_core.partial(
+        torch.testing.make_tensor, device=device, dtype=dtype, requires_grad=requires_grad
+    )
+
+    for shape, pad in cases:
+        yield opinfo_core.SampleInput(make_inp(shape), args=(pad,))
+
+
 # NOTE: How to create an OpInfo:
 # 1. Create a function that generates sample inputs for the op.
 #    This function should yield SampleInputs.
@@ -1358,6 +1406,13 @@ def sample_inputs__native_batch_norm_legit_no_stats(
         sample_inputs_func=sample_inputs__fft_c2c,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten._fft_r2c",
+        aten_name="_fft_r2c",
+        dtypes=common_dtype.floating_types(),
+        sample_inputs_func=sample_inputs__fft_r2c,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten._local_scalar_dense",
         aten_name="_local_scalar_dense",
@@ -1407,6 +1462,13 @@ def sample_inputs__native_batch_norm_legit_no_stats(
         sample_inputs_func=sample_inputs_convolution,
         supports_out=False,
     ),
+    opinfo_core.OpInfo(
+        "ops.aten.reflection_pad1d",
+        aten_name="ops.aten.reflection_pad1d",
+        dtypes=common_dtype.floating_and_complex_types_and(torch.int64, torch.bfloat16),
+        sample_inputs_func=sample_inputs_reflection_pad1d,
+        supports_out=False,
+    ),
     opinfo_core.OpInfo(
         "ops.aten.index.Tensor",
         aten_name="index.Tensor",
```

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -465,6 +465,12 @@ def _where_input_wrangler(
         trace_only=True,
         complex=True,
     ),
+    TorchLibOpInfo(
+        "ops.aten._fft_r2c",  # Custom from extra_opinfo
+        fft_ops.aten__fft_r2c,
+        tolerance={torch.float64: (2e-6, 2e-6), torch.float32: (3e-2, 3e-4)},
+        trace_only=True,
+    ),
     TorchLibOpInfo(
         "ops.aten._local_scalar_dense",
         core_ops.aten__local_scalar_dense,
@@ -1189,6 +1195,13 @@ def _where_input_wrangler(
         matcher=lambda sample: "weight" in sample.kwargs,
         reason="this Aten overload doesn't accept weight as kwargs",
     ),
+    TorchLibOpInfo(
+        "ops.aten.reflection_pad1d",
+        nn_ops.aten_reflection_pad1d,
+    ).xfail(
+        dtypes=(torch.int64,),
+        reason="Torch does not implement reflection_pad1d for int64.",
+    ),
     TorchLibOpInfo(
         "nn.functional.reflection_pad2d",
         nn_ops.aten_reflection_pad2d,
```
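The chained `.xfail(...)` works because the method records an expected failure and returns the info object itself; a minimal sketch of that builder pattern (illustrative, not the real `TorchLibOpInfo`):

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class _OpInfoSketch:
    name: str
    expected_failures: list[Any] = dataclasses.field(default_factory=list)

    def xfail(self, *, dtypes, reason):
        self.expected_failures.append((dtypes, reason))
        return self  # returning self keeps the registration a single expression
```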

requirements/lintrunner/requirements.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -3,7 +3,7 @@ lintrunner-adapters>=0.8.0
 # RUFF, RUFF-FIX
 ruff==0.1.6
 # MYPY
-mypy==1.7.0
+mypy==1.7.1
 types-PyYAML==6.0.12.11
 # PYLINT
 pylint==2.17.6
```

tools/onnx2script.py

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@ (new file)

```python
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""
onnx2script.py

This module provides a script to convert ONNX model files to Python scripts using the onnxscript library.

Usage:
    python onnx2script.py <input_file> [-o output_file] [-v]

Arguments:
    input_file: The ONNX model file to convert.
    -o, --output: The output file name. If not provided, the output will be named after the input file with a .py extension.
    -v, --verbose: Enables verbose mode. This suppresses the use of overloaded operators and inline constants.

Example:
    python onnx2script.py model.onnx -o model.py -v
"""

import argparse
import os
from typing import Optional

import onnx

import onnxscript


def convert2script(
    input_file_name: str, output_file_name: Optional[str], verbose: bool
) -> None:
    model = onnx.load(input_file_name, load_external_data=False)
    python_code = onnxscript.proto2python(
        model, use_operators=not verbose, inline_const=not verbose
    )

    # If output file name is not provided, use the input file name with .py extension
    if output_file_name is None:
        base_name = os.path.splitext(input_file_name)[0]  # Remove extension
        output_file_name = base_name + ".py"

    with open(output_file_name, "w", encoding="utf-8") as f:
        f.write(python_code)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert ONNX model file to onnxscript file")
    parser.add_argument("input", help="ONNX model file to convert")
    parser.add_argument("-o", "--output", help="Output file name")
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Verbose mode, suppresses use of overloaded operators and inline constants",
        default=False,
    )

    args = parser.parse_args()
    convert2script(args.input, args.output, args.verbose)
```
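The script is a thin wrapper over `onnxscript.proto2python`; the programmatic equivalent of its default (non-verbose) mode is simply:

```python
import onnx

import onnxscript

model = onnx.load("model.onnx", load_external_data=False)
print(onnxscript.proto2python(model, use_operators=True, inline_const=True))
```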
