
Commit 2d158ac

Update base for Update on "Auto generate OpSchema for functions | feat(op_schema)"
This change adds the capability to auto generate `OpSchema`.

### Changes

- Implement the `opschema` property in `OnnxFunction`
- Test on all torch_lib functions

### Next PR

Support trace_only functions

## Example

```python
from onnxscript.function_libs.torch_aten.ops import core, nn

print("core.aten_abs.opschema: ", core.aten_abs.opschema)
print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema)
```

Results

```
core.aten_abs.opschema:  OpSchema(
    name='aten_abs',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='abs(Tensor self) -> Tensor',
    type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={}
)

nn.aten_cross_entropy_loss.opschema:  OpSchema(
    name='aten_cross_entropy_loss',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor',
    type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>, description='', default_value=name: "ignore_index" i: -100 type: INT, required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT, required=False), 'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>, description='', default_value=name: "reduction" i: 1 type: INT, required=False), 'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>, description='', default_value=, required=True)}
)
```

Fixes #476

[ghstack-poisoned]
2 parents 1d4a0b4 + c299252 commit 2d158ac
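The generated schema is a regular `onnx.defs.OpSchema`, so its pieces can be inspected programmatically. A minimal sketch, relying only on the attribute names visible in the example output above:

```python
from onnxscript.function_libs.torch_aten.ops import core

schema = core.aten_abs.opschema  # the auto-generated OpSchema

# Walk the type constraints and formal parameters shown in the repr above.
for constraint in schema.type_constraints:
    print(constraint.type_param_str, "->", constraint.allowed_type_strs)
for formal_input in schema.inputs:
    print(formal_input.name, ":", formal_input.type_str)
```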

File tree

11 files changed: +73 −150 lines

.github/workflows/release.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -67,7 +67,7 @@ jobs:
           path: dist
       - name: Install wheel
         run: |
-          python -m pip install dist/*.whl
+          python -m pip install dist/*.whl --no-deps
       - name: Run tests
         run: |
           python -m pytest -v -n auto
```

.lintrunner.toml

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,4 +1,5 @@
 # Configuration for lintrunner https://github.com/suo/lintrunner
+merge_base_with = 'main'
 
 [[linter]]
 code = 'RUFF'
```

onnxscript/backend/onnx_export.py

Lines changed: 3 additions & 8 deletions
```diff
@@ -356,12 +356,7 @@ def _python_make_node(self, onnx_node, opsets, indent=0):
         if node.op_type == "Scan":
             return self._python_make_node_scan(node, opsets, indent=indent)
         raise RuntimeError(f"Unable to export node type {node.op_type!r} into python.")
-        if any(
-            map(
-                lambda att: hasattr(att, "g") and att.g and att.g.ByteSize() > 0,
-                node.attribute,
-            )
-        ):
+        if any(hasattr(att, "g") and att.g and att.g.ByteSize() > 0 for att in node.attribute):
             raise RuntimeError(f"Unable to export node type {node.op_type!r} into python.")
         ops = {
             "Add": "+",
@@ -438,7 +433,7 @@ def export_template(
     if hasattr(model_onnx, "functions"):
         for f in model_onnx.functions:
             unique_function_domain_version.add((f.domain, 1))
-    unique_function_domain_version_sorted = list(sorted(unique_function_domain_version))
+    unique_function_domain_version_sorted = sorted(unique_function_domain_version)
 
     if rename:
         variable_names: dict[str, str] = {}
@@ -486,7 +481,7 @@ def rename_variable(name):
         ts = _translate_type(t.type)
         its = ts.split("[", maxsplit=1)[0]
         unique_types.add(its)
-    context["unique_types"] = list(sorted(unique_types))
+    context["unique_types"] = sorted(unique_types)
 
     # functions
     functions = []
```

onnxscript/converter_test.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -82,7 +82,7 @@ def validate_save(
         ort.InferenceSession(model.SerializeToString())
     except (Fail, InvalidGraph, InvalidArgument) as e:
         raise AssertionError(
-            f"onnxruntime cannot load function " f"{f.name}\n--\n{model}"
+            f"onnxruntime cannot load function {f.name}\n--\n{model}"
         ) from e
     if shape_inference:
         model = onnx.shape_inference.infer_shapes(model)
@@ -423,7 +423,7 @@ def check_function(x, name, expected, eager=True):
             y = session.run(None, {"A": x})[0]
         except Exception as e:
             raise AssertionError(
-                f"Unable to run ONNX for function {name!r} " f"due to {e!r}\n{onx}."
+                f"Unable to run ONNX for function {name!r} due to {e!r}\n{onx}."
             ) from e
         self.assertEqual(y.tolist(), expected)
         f = getattr(getitem, name)
@@ -477,7 +477,7 @@ def check_function(x, name, expected, eager=True):
             y = session.run(None, {"A": x})[0]
         except Exception as e:
             raise AssertionError(
-                f"Unable to run ONNX for function {name!r} " f"due to {e!r}\n{onx}."
+                f"Unable to run ONNX for function {name!r} due to {e!r}\n{onx}."
             ) from e
         self.assertEqual(y.tolist(), expected)
         f = getattr(getitem39, name)
@@ -528,7 +528,7 @@ def check_run(self, onnxfn, inputs, expected_output):
         model = onnxfn.to_model_proto()
         session = ort.InferenceSession(model.SerializeToString())
         input_names = [x.name for x in model.graph.input]
-        input_dict = {x: value for (x, value) in zip(input_names, inputs)}
+        input_dict = dict(zip(input_names, inputs))
         output = session.run(None, input_dict)[0]
         np.testing.assert_equal(output, expected_output)
```

onnxscript/function_libs/torch_aten/ops/core.py

Lines changed: 47 additions & 92 deletions
```diff
@@ -5365,17 +5365,58 @@ def aten_slice_copy(
     raise NotImplementedError()
 
 
+@torch_op("aten::slice_scatter", trace_only=True)
 def aten_slice_scatter(
-    self: TensorType,
-    src: TensorType,
+    self: TTensor,
+    src: TTensor,
     dim: int = 0,
     start: Optional[INT64] = None,
     end: Optional[INT64] = None,
     step: INT64 = 1,
-) -> TensorType:
+) -> TTensor:
     """slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor"""
 
-    raise NotImplementedError()
+    # Although 'start' and 'end' can be None in the signature, 'start' must actually be specified
+    # Assert(start is not None)
+    # 'end' must also be specified, and end - start must equal the size of 'src' along 'dim'
+    # Assert(end - start == shape(src) > 0)
+    # See the torch documentation for more information:
+    # https://pytorch.org/docs/master/generated/torch.slice_scatter.html?highlight=slice_scatter#torch.slice_scatter
+    # e.g. if dim=2 and rank=5, the permutation is [0,1]+[4]+[2,3] = [0,1,4,2,3]
+    last = len(src.shape)
+    perm = list(range(0, last))
+    perm.insert(dim, perm.pop(-1))
+    return _aten_slice_scatter_onnx(self, src, start, end, step, dim, perm)
+
+
+@torch_op("aten::slice_scatter", private=True)
+def _aten_slice_scatter_onnx(
+    self: TTensor,
+    src: TTensor,
+    start: INT64,
+    end: INT64,
+    step: INT64,
+    dim: int,
+    perm: Sequence[int],
+) -> TTensor:
+    neg_1 = op.Constant(value_ints=[-1])
+    # Get the shape except for the specified dim
+    # e.g. if dim=2 and shape=(2,3,5,7), shape_expand will be (2,3,7,1)
+    src_shape = op.Shape(src)
+    last_dim = op.Reshape(op.Size(src_shape), neg_1)
+    dim_tensor = op.Reshape(op.Constant(value_int=dim), neg_1)
+    shape_before_dim = op.Slice(src_shape, op.Constant(value_ints=[0]), dim_tensor)
+    shape_after_dim = op.Slice(src_shape, op.Add(dim_tensor, 1), last_dim)
+    shape_expand = op.Concat(
+        shape_before_dim, shape_after_dim, op.Constant(value_ints=[1]), axis=0
+    )
+    # Generate the indices; a transpose is still needed to finalize them
+    # e.g. [[0,1,2],[0,1,2],...,[0,1,2]], total count = 2x3x7
+    index_base = op.Range(start, end, step)  # e.g. [0,1,2]
+    index_expand = op.Expand(index_base, shape_expand)
+    indices = op.Transpose(index_expand, perm=perm)
+
+    return op.ScatterElements(self, indices, src, axis=dim)
 
 
 def aten_slogdet(self: TensorType) -> tuple[TensorType, TensorType]:
@@ -6043,96 +6084,10 @@ def aten_var(self: TensorType, unbiased: bool = True) -> TensorType:
     raise NotImplementedError()
 
 
-@torch_op("aten::var_mean", trace_only=True)
-def aten_var_mean(self: TReal, unbiased: bool = True) -> Tuple[TReal, TReal]:
+def aten_var_mean(self: TensorType, unbiased: bool = True) -> tuple[TensorType, TensorType]:
     """var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)"""
 
-    # Assume bool(True) and int(1) are same in ONNX, so pass "unbiased" directly as "correction"
-    # If not this case, should be explicitly set correction value according to unbiased value
-    return _aten_var_mean_onnx(self, correction=int(unbiased), keepdim=False)
-
-
-@torch_op("aten::var_mean", overload=True, trace_only=True)
-def aten_var_mean_dim(
-    self: TReal, dim: Optional[int], unbiased: bool = True, keepdim: bool = False
-) -> Tuple[TReal, TReal]:
-    """var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)"""
-
-    # Although dim is Optional in signature, but we assume it must has value for this overload
-    # Assert(dim is not None)
-    if isinstance(dim, Tuple):
-        dim_tensor = op.Constant(value_ints=dim)
-    else:
-        dim_tensor = op.Constant(value_int=dim)
-    return _aten_var_mean_dim_onnx(self, dim_tensor, correction=int(unbiased), keepdim=keepdim)
-
-
-@torch_op("aten::var_mean", overload=True, trace_only=True)
-def aten_var_mean_correction(
-    self: TReal,
-    dim: Optional[int] = None,
-    correction: Optional[int] = None,
-    keepdim: bool = False,
-) -> Tuple[TReal, TReal]:
-    """var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)"""
-
-    if correction is None:
-        correction = 1
-
-    if dim is None:
-        var, mean = _aten_var_mean_onnx(self, correction, keepdim)
-    else:
-        if isinstance(dim, Tuple):
-            dim_tensor = op.Constant(value_ints=dim)
-        else:
-            dim_tensor = op.Constant(value_int=dim)
-        var, mean = _aten_var_mean_dim_onnx(self, dim_tensor, correction, keepdim)
-    return var, mean
-
-
-@torch_op("aten::var_mean", private=True)
-def _aten_var_mean_onnx(
-    self: TReal, correction: int = 1, keepdim: bool = False
-) -> Tuple[TReal, TReal]:
-    # Compute mean and var
-    mean = op.ReduceMean(self, keepdims=keepdim)
-    sub_mean = op.Sub(self, mean)
-    sqr_mean = op.Mul(sub_mean, sub_mean)
-    var = op.ReduceMean(sqr_mean, keepdims=keepdim)
-    # Adjust var according to correction value
-    if correction != 0:
-        self_shape = op.Shape(self)
-        numel_int = op.ReduceProd(self_shape, keepdims=0)
-        numel_float = op.Cast(numel_int, to=FLOAT.dtype)
-        mul = op.Mul(var, numel_float)
-        sub = op.Sub(numel_int, correction)
-        var = op.Div(mul, op.Cast(sub, to=FLOAT.dtype))
-
-    return var, mean
-
-
-@torch_op("aten::var_mean", private=True)
-def _aten_var_mean_dim_onnx(
-    self: TReal, dim: INT64, correction: int, keepdim: bool = False
-) -> Tuple[TReal, TReal]:
-    if op.Size(op.Shape(dim)) == 0:
-        dim = op.Unsqueeze(dim, axes=0)
-    # Computer mean and var
-    mean = op.ReduceMean(self, dim, keepdims=keepdim)
-    sub_mean = op.Sub(self, op.ReduceMean(self, dim, keepdims=1))
-    sqr_mean = op.Mul(sub_mean, sub_mean)
-    var = op.ReduceMean(sqr_mean, dim, keepdims=keepdim)
-    # Adjust var according to correction value
-    if correction != 0:
-        self_shape = op.Shape(self)
-        dim_size = op.Gather(self_shape, dim, axis=0)
-        numel_int = op.ReduceProd(dim_size, keepdims=0)
-        numel_float = op.Cast(numel_int, to=FLOAT.dtype)
-        mul = op.Mul(var, numel_float)
-        sub = op.Sub(numel_int, correction)
-        var = op.Div(mul, op.Cast(sub, to=FLOAT.dtype))
-
-    return var, mean
+    raise NotImplementedError()
 
 
 def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
```
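For context on what the new decomposition above must reproduce: `slice_scatter` embeds `src` into the slice of `self` taken along `dim`. A minimal sketch of the expected behavior using PyTorch directly (shapes are illustrative only, not part of the change):

```python
import torch

base = torch.zeros(2, 3, 5, 7)
src = torch.ones(2, 3, 3, 7)  # end - start == 3 matches src's size along dim=2

# Equivalent to base[:, :, 1:4, :] = src, but out of place; this is what the
# Range -> Expand -> Transpose -> ScatterElements sequence above emulates.
out = torch.slice_scatter(base, src, dim=2, start=1, end=4, step=1)

assert torch.equal(out[:, :, 1:4, :], src)
assert torch.equal(out[:, :, 0, :], torch.zeros(2, 3, 7))
```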

onnxscript/function_libs/torch_aten/ops/nn.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -529,10 +529,11 @@ def aten_hardswish_backward(grad_output: TensorType, self: TensorType) -> Tensor
     raise NotImplementedError()
 
 
-def aten_hardtanh(self: TensorType, min_val: float = -1.0, max_val: float = 1.0) -> TensorType:
+@torch_op("aten::hardtanh")
+def aten_hardtanh(self: TReal, min_val: float = -1.0, max_val: float = 1.0) -> TReal:
     """hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor"""
 
-    raise NotImplementedError()
+    return op.Clip(self, min_val, max_val)
 
 
 def aten_hardtanh_backward(
```
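`hardtanh` is an elementwise clamp, so lowering it to ONNX `Clip` is direct. A small sanity sketch of the equivalence (illustrative values, not part of the change):

```python
import numpy as np
import torch

x = torch.tensor([-3.0, -0.5, 0.0, 2.5])
expected = torch.nn.functional.hardtanh(x, min_val=-1.0, max_val=1.0)

# ONNX Clip(input, min, max) computes the same elementwise clamp.
clipped = np.clip(x.numpy(), -1.0, 1.0)
np.testing.assert_allclose(clipped, expected.numpy())
```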

onnxscript/tests/eager_test.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -22,7 +22,7 @@ def _fft(x, fft_length, axis=-1):
     tr = np.transpose(merged, list(perm))
     if tr.shape[-1] != 2:
         raise AssertionError(
-            f"Unexpected shape {tr.shape}, x.shape={x.shape} " f"fft_length={fft_length}."
+            f"Unexpected shape {tr.shape}, x.shape={x.shape} fft_length={fft_length}."
         )
     return tr
 
@@ -48,7 +48,7 @@ def _ifft(x, fft_length, axis=-1):
     tr = np.transpose(merged, list(perm))
     if tr.shape[-1] != 2:
         raise AssertionError(
-            f"Unexpected shape {tr.shape}, x.shape={x.shape} " f"fft_length={fft_length}."
+            f"Unexpected shape {tr.shape}, x.shape={x.shape} fft_length={fft_length}."
         )
     return tr
 
```

onnxscript/tests/function_libs/torch_aten/ops_correctness_test.py

Lines changed: 2 additions & 33 deletions
```diff
@@ -610,6 +610,7 @@ def _where_input_wrangler(
     "nn.functional.dropout": (core_ops.aten_dropout, _dropout_input_wrangler),
     "nn.functional.elu": nn_ops.aten_elu,
     "nn.functional.embedding": (core_ops.aten_embedding, _embedding_input_wrangler),
+    "nn.functional.hardtanh": nn_ops.aten_hardtanh,
     "nn.functional.leaky_relu": nn_ops.aten_leaky_relu,
     "nn.functional.logsigmoid": nn_ops.aten_log_sigmoid,
     "nn.functional.nll_loss_weight": (nn_ops.aten_nll_loss_weight, _nll_loss_input_wrangler),
@@ -731,12 +732,10 @@ def _where_input_wrangler(
     ),
     "ones_like": core_ops.aten_ones_like,
     "scatter_reduce": (core_ops.aten_scatter_reduce, _scatter_reduce_input_wrangler),
+    "slice_scatter": core_ops.aten_slice_scatter,
     "slice": core_ops.aten_slice,
     "sum": (core_ops.aten_sum_dim_IntList, _sum_input_wrangler),
     "transpose": core_ops.aten_transpose,
-    "var_mean": core_ops.aten_var_mean,
-    "var_mean_dim": core_ops.aten_var_mean_dim,
-    "var_mean_correction": core_ops.aten_var_mean_correction,
     "zeros_like": core_ops.aten_zeros_like,
 }
 
@@ -1194,27 +1193,6 @@ def _where_input_wrangler(
         matcher=lambda sample: not (len(sample.args) > 0 and isinstance(sample.args[0], int)),
         reason="this Aten overload only support one tensor as input and one int as args by design",
     ),
-    skip(
-        "var_mean",
-        # kwargs is empty
-        matcher=lambda sample: len(sample.kwargs) > 0,
-        reason="this Aten overload only support input[0]=tensor and input[1]=bool as input without any kwargs",
-    ),
-    skip(
-        "var_mean_dim",
-        # kwargs["dim"] must exist, kwargs["correction"] must not exist
-        matcher=lambda sample: not (
-            sample.kwargs.get("dim", None) is not None
-            and sample.kwargs.get("correction", None) is None
-        ),
-        reason="this Aten overload only support with 'dim' argument and without 'correction' argument",
-    ),
-    skip(
-        "var_mean_correction",
-        # Don't accept input[1]=bool and 'correction' must be in kwargs
-        matcher=lambda sample: len(sample.args) > 0 or "correction" not in sample.kwargs,
-        reason="this Aten overload only support when correction attribute exists",
-    ),
     skip(
         "unflatten",
         matcher=lambda sample: any(dim == 0 for dim in sample.input.shape),
@@ -1284,15 +1262,6 @@ def _where_input_wrangler(
 
 duplicate_opinfo(OPS_DB, "squeeze", ("squeeze_dim",))
 
-duplicate_opinfo(
-    OPS_DB,
-    "var_mean",
-    (
-        "var_mean_dim",
-        "var_mean_correction",
-    ),
-)
-
 
 # END OF SECTION TO MODIFY #####################################################
 
```

opgen/onnx_opset_builder.py

Lines changed: 1 addition & 3 deletions
```diff
@@ -123,9 +123,7 @@ def _make_opset_module(self, domain: str, version: int):
             cg.FunctionDef(
                 "__new__",
                 cg.Arg("cls"),
-                body=cg.ThunkStmt(
-                    f"return Opset.__new__(cls, " f"{domain!r}, {version!r})"
-                ),
+                body=cg.ThunkStmt(f"return Opset.__new__(cls, {domain!r}, {version!r})"),
             ),
             cg.FunctionDef(
                 "__init__", cg.Arg("self"), body=cg.ThunkStmt("super().__init__()")
```

pyproject.toml

Lines changed: 8 additions & 4 deletions
```diff
@@ -106,19 +106,23 @@ convention = "google"
 [tool.ruff]
 target-version = "py38"
 select = [
+    "B", # flake8-bugbear
+    "C4", # flake8-comprehensions
     "D", # pydocstyle
     "E", # pycodestyle
     "F", # Pyflakes
-    "W", # pycodestyle
-    "B", # flake8-bugbear
+    "G", # flake8-logging-format
+    "ISC", # flake8-implicit-str-concat
     "N", # pep8-naming
     "NPY", # modern numpy
-    "YTT", # flake8-2020
     "RUF", # Ruff-specific rules
-    "UP", # pyupgrade
     "TID252", # Disallow relative imports
+    "UP", # pyupgrade
+    "W", # pycodestyle
+    "YTT", # flake8-2020
 ]
 ignore = [
+    "C408", # Sometimes it is preferable when we construct kwargs
     "D1", # D1 is for missing docstrings, which is not yet enforced.
     "D202", # D202 Too strict. "No blank lines allowed after function docstring"
     "D205", # D205 Too strict. "1 blank line required between summary line and description"
```
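The new `ISC` and `C4` selections are what drive the string-literal and comprehension cleanups elsewhere in this commit. A hypothetical before/after of the patterns they flag (names invented for illustration):

```python
name = "aten_abs"
err = ValueError("boom")

# Before (ISC001, implicit concatenation of adjacent f-string literals):
#     msg = f"Unable to run ONNX for function {name!r} " f"due to {err!r}"
# After, as a single literal:
msg = f"Unable to run ONNX for function {name!r} due to {err!r}"

input_names = ["A", "B"]
inputs = [1, 2]
# Before (C416, unnecessary dict comprehension):
#     input_dict = {x: value for (x, value) in zip(input_names, inputs)}
# After, using the dict builtin:
input_dict = dict(zip(input_names, inputs))
assert input_dict == {"A": 1, "B": 2}
```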

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,5 +27,5 @@ pyyaml
 torch>=1.13
 
 # Lint
-lintrunner
+lintrunner>=0.10.7
 lintrunner_adapters>=0.7.0
```
