Skip to content

Commit 6cfa67c

Browse files
committed
Update on "Auto generate OpSchema for functions | feat(op_schema)"
This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? 
weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>, description='', default_value=, required=True)} ) ``` Fixes #476 [ghstack-poisoned]
2 parents b61543a + 2c6be92 commit 6cfa67c

File tree

2 files changed

+30
-14
lines changed

2 files changed

+30
-14
lines changed

onnxscript/type_annotation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@
4444

4545
_LIST_CONSTRUCTORS = frozenset([list, typing.List, typing.Sequence, collections.abc.Sequence])
4646

47+
ALL_TYPE_STRINGS = (
48+
"tensor(bfloat16)",
49+
"tensor(bool)",
50+
"tensor(double)",
51+
"tensor(float)",
52+
"tensor(float16)",
53+
"tensor(int16)",
54+
"tensor(int32)",
55+
"tensor(int64)",
56+
"tensor(int8)",
57+
"tensor(string)",
58+
"tensor(uint16)",
59+
"tensor(uint32)",
60+
"tensor(uint64)",
61+
"tensor(uint8)",
62+
)
63+
4764

4865
def _remove_annotation(typeinfo: TypeAnnotationValue) -> TypeAnnotationValue:
4966
"""Remove Annotated wrapper if present, otherwise return typeinfo as is."""

onnxscript/type_annotation_test.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,17 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6-
from typing import Optional, Sequence, TypeVar, Union
76
import unittest
7+
from typing import Any, Optional, Sequence, TypeVar, Union
88

99
import parameterized
1010

11-
import onnxscript.testing
1211
import onnxscript
13-
from onnxscript import script
12+
import onnxscript.testing
13+
from onnxscript import FLOAT, INT64, script, type_annotation
1414
from onnxscript.onnx_opset import opset15 as op
15-
from onnxscript import FLOAT, INT64
1615
from onnxscript.tests.common import testutils
1716

18-
from onnxscript import type_annotation
19-
2017

2118
class TypeAnnotationTest(testutils.TestBase):
2219
def test_type_annotation(self):
@@ -99,7 +96,7 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
9996
_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT])
10097

10198

102-
class UtilityFunctionsTest(unittest.TestCase):
99+
class TypeConversionFunctionsTest(unittest.TestCase):
103100
@parameterized.parameterized.expand(
104101
[
105102
(
@@ -121,10 +118,12 @@ class UtilityFunctionsTest(unittest.TestCase):
121118
(
122119
"optional_tensor_type_all",
123120
Optional[onnxscript.onnx_types.TensorType],
124-
type_annotation.ALL_TYPE_STRINGS
125-
+ [
126-
f"optional({tensor_type})"
127-
for tensor_type in type_annotation.ALL_TYPE_STRINGS
121+
[
122+
*type_annotation.ALL_TYPE_STRINGS,
123+
*[
124+
f"optional({tensor_type})"
125+
for tensor_type in type_annotation.ALL_TYPE_STRINGS
126+
],
128127
],
129128
),
130129
(
@@ -214,7 +213,7 @@ class UtilityFunctionsTest(unittest.TestCase):
214213
),
215214
]
216215
)
217-
def test_pytype_to_input_strings(self, _, pytype: Any, expected):
216+
def test_pytype_to_input_strings(self, _, pytype: Any, expected: list[str]):
218217
self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected)
219218

220219
@parameterized.parameterized.expand(
@@ -231,15 +230,15 @@ def test_pytype_to_input_strings(self, _, pytype: Any, expected):
231230
Sequence[_TestTypeVarOneBound],
232231
"Sequence_TestTypeVarOneBound",
233232
),
234-
("normal_type", INT64, "None"),
233+
("normal_type", INT64, None),
235234
("union_type", Union[INT64, FLOAT], None),
236235
("optional_type", Optional[INT64], None),
237236
("sequence_type", Sequence[INT64], None),
238237
("optional_sequence_type", Optional[Sequence[INT64]], None),
239238
("optional_union_type", Optional[Union[INT64, FLOAT]], None),
240239
]
241240
)
242-
def get_type_constraint_name(self, _: str, pytype, expected):
241+
def get_type_constraint_name(self, _: str, pytype: Any, expected: Optional[str]):
243242
self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected)
244243

245244

0 commit comments

Comments (0)