
Commit 2a73b15

Auto generate OpSchema for functions | feat(op_schema) (#626)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):

* #692
* #674
* __->__ #626
* #684

This change adds the capability to auto-generate `OpSchema`.

### Changes

- Implement the `opschema` property in `OnnxFunction`
- Test on all torch_lib functions

### Next PR

Support trace_only functions

## Example

```python
from onnxscript.function_libs.torch_aten.ops import core, nn

print("core.aten_abs.opschema: ", core.aten_abs.opschema)
print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema)
```

Results

```
core.aten_abs.opschema:  OpSchema(
    name='aten_abs',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='abs(Tensor self) -> Tensor',
    type_constraints=[
        OpSchema.TypeConstraintParam(
            type_param_str='TReal',
            allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)',
                               'tensor(int32)', 'tensor(int64)', 'tensor(float16)',
                               'tensor(double)', 'tensor(bfloat16)'],
            description='')],
    inputs=[
        OpSchema.FormalParameter(
            name='self', type_str='TReal', description='',
            param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
            min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[
        OpSchema.FormalParameter(
            name='return_val', type_str='TReal', description='',
            param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
            min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={}
)

nn.aten_cross_entropy_loss.opschema:  OpSchema(
    name='aten_cross_entropy_loss',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor',
    type_constraints=[
        OpSchema.TypeConstraintParam(
            type_param_str='TFloatOrBFloat16',
            allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'],
            description=''),
        OpSchema.TypeConstraintParam(
            type_param_str='T1',
            allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'],
            description='')],
    inputs=[
        OpSchema.FormalParameter(
            name='self', type_str='TFloatOrBFloat16', description='',
            param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
            min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>),
        OpSchema.FormalParameter(
            name='weight', type_str='T1', description='',
            param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True,
            min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[
        OpSchema.FormalParameter(
            name='result_10', type_str='TFloatOrBFloat16', description='',
            param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True,
            min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={
        'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>,
            description='', default_value=name: "ignore_index" i: -100 type: INT, required=False),
        'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>,
            description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT, required=False),
        'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>,
            description='', default_value=name: "reduction" i: 1 type: INT, required=False),
        'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>,
            description='', default_value=, required=True)}
)
```

Fixes #476
1 parent d3ce597

3 files changed: +136 -5 lines changed

### onnxscript/autocast.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -86,7 +86,7 @@ def cast(x, typeinfo) -> tensor.Tensor:
     return cast_inputs(get_type_info, cast, op_schema, *args)
 
 
-def static_cast_inputs(converter, op_schema: OpSchema, *args):
+def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
     """Used for autocast during script-translation."""
     if op_schema is None:
         return args
```
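The widened annotation matches behavior the body already had: when no schema is available, `static_cast_inputs` returns its arguments untouched. A minimal, self-contained sketch of that pass-through path (the argument values are invented; `converter` is never consulted when `op_schema` is `None`):

```python
from typing import Optional

from onnx.defs import OpSchema


def static_cast_inputs(converter, op_schema: Optional[OpSchema], *args):
    """Used for autocast during script-translation (casting logic elided)."""
    if op_schema is None:  # no schema to cast against: pass inputs through
        return args
    raise NotImplementedError  # the real function casts args per the schema


# With op_schema=None the inputs come back unchanged.
assert static_cast_inputs(None, None, "x", "y") == ("x", "y")
```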

### onnxscript/tests/function_libs/torch_lib/ops_test.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -106,6 +106,19 @@ def test_script_function_passes_checker(self, _, func_with_wrangler):
         function_proto = func.to_function_proto()
         onnx.checker.check_function(function_proto)  # type: ignore[attr-defined]
 
+    @parameterized.parameterized.expand(
+        list(ops_test_data.OPINFO_FUNCTION_MAPPING_SCRIPTED.items())
+    )
+    @unittest.skipIf(
+        version_utils.onnx_older_than("1.15"),
+        "OpSchema is not writable before ONNX 1.15",
+    )
+    def test_script_function_has_op_schema(self, _, func_with_wrangler):
+        func, _ = _split_function_and_wrangler(func_with_wrangler)
+        schema = func.opschema
+        self.assertIsNotNone(schema)
+        self.assertEqual(schema.name, func.name)
+
 
 def run_test_output_match(
     test_suite: unittest.TestCase,
```
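For a sense of what the new test asserts outside the torch_lib parameterization, here is a hedged sketch against a toy scripted function (`toy_double` is invented; `opschema` returns `None` on ONNX releases where `OpSchema` is not writable):

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def toy_double(x: FLOAT[...]) -> FLOAT[...]:
    # Invented example function; its auto-generated schema carries its name.
    return op.Add(x, x)


schema = toy_double.opschema
if schema is not None:  # None when the installed onnx cannot build schemas
    assert schema.name == "toy_double"
```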

### onnxscript/values.py

Lines changed: 122 additions & 4 deletions
```diff
@@ -8,12 +8,14 @@
 import logging
 import types
 from enum import IntFlag
-from typing import Any, Optional, Sequence, _GenericAlias  # type: ignore[attr-defined]
+from typing import _GenericAlias  # type: ignore[attr-defined]
+from typing import Any, Optional, Sequence
 
 import onnx
 import onnx.defs
 
-from onnxscript import irbuilder, sourceinfo
+from onnxscript import irbuilder, sourceinfo, type_annotation
+from onnxscript._internal import version_utils
 
 _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
     onnx.defs.OpSchema.AttrType.FLOAT: float,
@@ -34,6 +36,7 @@
 
 # A special value to indicate that the default value is not specified
 _EmptyDefault = object()
+_ONNX_OP_SCHEMA_WRITABLE = not version_utils.onnx_older_than("1.14")
 
 
 class Opset:
@@ -173,7 +176,7 @@ def __init__(
     ) -> None:
         self.opset = opset
         self.opname = opname
-        self.opschema = opschema
+        self._opschema = opschema
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
 
     def __call__(self, *args, **kwargs):
@@ -190,9 +193,13 @@ def __call__(self, *args, **kwargs):
     def is_single_op(self) -> bool:
         return isinstance(self.opname, str)
 
+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        return self._opschema
+
     def get_schema(self) -> Optional[onnx.defs.OpSchema]:
         """Returns the ONNX OpSchema for this op."""
-        if self.opschema:
+        if self.opschema is not None:
             return self.opschema
         return self.opset[self.opname]
 
@@ -249,6 +256,100 @@ class OnnxClosure:
     function: Any
 
 
+@dataclasses.dataclass
+class TypeConstraint:
+    """Represents a type constraint for an ONNX op.
+
+    Attributes:
+        name: The name of the type constraint.
+        allowed_types: The allowed types for the type constraint.
+    """
+
+    name: str
+    allowed_types: list[str]
+    description: str = ""
+
+    def as_tuple(self) -> tuple[str, list[str], str]:
+        """Returns the type constraint as a tuple."""
+        return (self.name, self.allowed_types, self.description)
+
+
+def op_schema_from_function_ir(
+    function_ir: irbuilder.IRFunction, opset: Opset
+) -> onnx.defs.OpSchema:
+    """Construct an ONNX OpSchema from an IRFunction."""
+
+    # Find all distinct types in the inputs and outputs
+    distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union(
+        {arg.typeinfo for arg in function_ir.outputs}
+    )
+    # Create a mapping from type to a unique name
+    type_to_constraint = {}
+    for i, type_ in enumerate(distinct_types):
+        name = f"T{i}"
+        type_to_constraint[type_] = TypeConstraint(
+            name=type_annotation.get_type_constraint_name(type_) or name,
+            allowed_types=type_annotation.pytype_to_type_strings(type_),
+        )
+
+    formal_inputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.inputs
+    ]
+    formal_outputs = [
+        onnx.defs.OpSchema.FormalParameter(
+            arg.name,
+            type_to_constraint[arg.typeinfo].name,
+            param_option=(
+                onnx.defs.OpSchema.FormalParameterOption.Optional
+                if type_annotation.is_optional(arg.typeinfo)
+                else onnx.defs.OpSchema.FormalParameterOption.Single
+            ),
+            # TODO(justinchu): Check this is_homogeneous thing
+            is_homogeneous=True,
+        )
+        for arg in function_ir.outputs
+    ]
+
+    return onnx.defs.OpSchema(
+        function_ir.name,
+        opset.domain,
+        since_version=opset.version,
+        doc=function_ir.docstring,
+        inputs=formal_inputs,
+        outputs=formal_outputs,
+        type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()],
+        attributes=[
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    type=onnx.defs.OpSchema.AttrType(attr.type),
+                )
+                for attr in function_ir.attrs
+                if not attr.has_default
+            ],
+            *[
+                onnx.defs.OpSchema.Attribute(
+                    attr.name,
+                    default_value=attr.attr_proto,
+                )
+                for attr in function_ir.attrs
+                if attr.has_default
+            ],
+        ],
+    )
+
+
 class OnnxFunction(Op):
     """Represents an ONNX op for which a function-body has been defined in onnxscript.
 
```
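`op_schema_from_function_ir` relies on `onnx.defs.OpSchema` being constructible from Python, which is what the `_ONNX_OP_SCHEMA_WRITABLE` guard above checks for. A standalone sketch of that constructor, mirroring the call the helper emits (the `CustomAbs` op, `my.domain` domain, and `TReal` constraint are made-up illustration names):

```python
import onnx.defs

# Build a schema by hand, in the same shape op_schema_from_function_ir produces.
schema = onnx.defs.OpSchema(
    "CustomAbs",  # made-up op name
    "my.domain",  # made-up domain
    since_version=1,
    doc="abs(Tensor self) -> Tensor",
    inputs=[onnx.defs.OpSchema.FormalParameter("self", "TReal")],
    outputs=[onnx.defs.OpSchema.FormalParameter("return_val", "TReal")],
    # (name, allowed types, description) tuples, as TypeConstraint.as_tuple returns
    type_constraints=[("TReal", ["tensor(float)", "tensor(double)"], "")],
    attributes=[],
)
print(schema.name, schema.since_version)  # CustomAbs 1
```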
```diff
@@ -276,12 +377,26 @@ def __init__(
         self.source = source
         self.kwargs = kwargs
         self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
+        self._opschema: Optional[onnx.defs.OpSchema] = None
 
     @property
     def name(self):
         """Returns the function name."""
         return self.opname
 
+    @property
+    def opschema(self) -> Optional[onnx.defs.OpSchema]:
+        """Construct an OpSchema from function_ir."""
+        if self._opschema is not None:
+            return self._opschema
+
+        if not _ONNX_OP_SCHEMA_WRITABLE:
+            return None
+
+        self._opschema = op_schema_from_function_ir(self.function_ir, self.opset)
+
+        return self._opschema
+
     def __getitem__(self, instance):
         """Returns a lambda to evaluate function using given evaluator instance.
 
@@ -311,6 +426,9 @@ def param_schemas(self) -> tuple[ParamSchema, ...]:
         if self._param_schemas is not None:
             return self._param_schemas
 
+        # NOTE: We generate the parameter schemas from the function_ir instead
+        # of relying on the auto generated OpSchema because we need to preserve the keyword
+        # argument order from the Python function definition, which is lost in OpSchema.
         function_ir = self.function_ir
         # The first len(func_ir.inputs) arguments are onnx inputs
         inputs = function_ir.inputs
```
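The NOTE above marks the one place this change deliberately does not reuse the auto-generated schema: `OpSchema` keeps attributes in a name-keyed map (the example output in the commit message lists them alphabetically: `ignore_index`, `label_smoothing`, `reduction`, `target`), whereas a Python signature preserves declaration order. A small illustration with an invented signature modeled on `aten_cross_entropy_loss`:

```python
import inspect

# Invented signature for illustration only.
def f(self, target, weight=None, reduction=1, ignore_index=-100, label_smoothing=0.0):
    ...

# Declaration order survives in the Python signature...
print(list(inspect.signature(f).parameters))
# ['self', 'target', 'weight', 'reduction', 'ignore_index', 'label_smoothing']

# ...while a sorted, name-keyed attribute map loses it:
print(sorted(["reduction", "ignore_index", "label_smoothing", "target"]))
# ['ignore_index', 'label_smoothing', 'reduction', 'target']
```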
