Skip to content

Commit b61543a

Browse files
committed
Update on "Auto generate OpSchema for functions | feat(op_schema)"
This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>, description='', default_value=, required=True)} ) ``` Fixes #476 [ghstack-poisoned]
2 parents 19f9484 + 31bb69c commit b61543a

File tree

2 files changed

+144
-15
lines changed

2 files changed

+144
-15
lines changed

onnxscript/type_annotation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,6 @@ def pytype_to_input_strings(pytype: TypeAnnotationValue) -> list[str]:
199199

200200

201201
def get_type_constraint_name(pytype: TypeAnnotationValue) -> Optional[str]:
202-
pass
202+
if isinstance(pytype, TypeVar):
203+
return pytype.__name__
204+
return None

onnxscript/type_annotation_test.py

Lines changed: 141 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6-
from typing import TypeVar, Union
6+
from typing import Optional, Sequence, TypeVar, Union
77
import unittest
88

99
import parameterized
@@ -12,7 +12,7 @@
1212
import onnxscript
1313
from onnxscript import script
1414
from onnxscript.onnx_opset import opset15 as op
15-
from onnxscript.onnx_types import FLOAT
15+
from onnxscript import FLOAT, INT64
1616
from onnxscript.tests.common import testutils
1717

1818
from onnxscript import type_annotation
@@ -93,28 +93,155 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
9393
bool_type_for_attribute, bool_type_for_attribute_txt
9494
)
9595

96-
_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT)
97-
_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64)
98-
_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT])
96+
97+
_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT)
98+
_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64)
99+
_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT])
99100

100101

101102
class UtilityFunctionsTest(unittest.TestCase):
102103
@parameterized.parameterized.expand(
103104
[
104-
("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS),
105-
("tensor_type", onnxscript.INT64, ["tensor(int64)"]),
106-
("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]),
107-
("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]),
108-
("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]),
105+
(
106+
"tensor_type_all",
107+
onnxscript.onnx_types.TensorType,
108+
type_annotation.ALL_TYPE_STRINGS,
109+
),
110+
("tensor_type", INT64, ["tensor(int64)"]),
111+
("tensor_type_union", Union[INT64, FLOAT], ["tensor(int64)", "tensor(float)"]),
112+
("tensor_type_variadic_shape", INT64[...], ["tensor(int64)"]),
113+
("tensor_type_shape", INT64[10], ["tensor(int64)"]),
114+
(
115+
"type_var_constraints",
116+
_TestTypeVarConstraints,
117+
["tensor(int64)", "tensor(float)"],
118+
),
109119
("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]),
110120
("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]),
121+
(
122+
"optional_tensor_type_all",
123+
Optional[onnxscript.onnx_types.TensorType],
124+
type_annotation.ALL_TYPE_STRINGS
125+
+ [
126+
f"optional({tensor_type})"
127+
for tensor_type in type_annotation.ALL_TYPE_STRINGS
128+
],
129+
),
130+
(
131+
"optional_tensor_type",
132+
Optional[INT64],
133+
["tensor(int64)", "optional(tensor(int64))"],
134+
),
135+
(
136+
"optional_tensor_type_union",
137+
Optional[Union[INT64, FLOAT]],
138+
[
139+
"tensor(int64)",
140+
"tensor(float)",
141+
"optional(tensor(int64))",
142+
"optional(tensor(float))",
143+
],
144+
),
145+
(
146+
"optional_tensor_type_variadic_shape",
147+
Optional[INT64[...]],
148+
["tensor(int64)", "optional(tensor(int64))"],
149+
),
150+
(
151+
"optional_tensor_type_shape",
152+
Optional[INT64[10]],
153+
["tensor(int64)", "optional(tensor(int64))"],
154+
),
155+
(
156+
"optional_type_var_constraints",
157+
Optional[_TestTypeVarConstraints],
158+
[
159+
"tensor(int64)",
160+
"tensor(float)",
161+
"optional(tensor(int64))",
162+
"optional(tensor(float))",
163+
],
164+
),
165+
(
166+
"optional_type_bound_one",
167+
Optional[_TestTypeVarOneBound],
168+
["tensor(int64)", "optional(tensor(int64))"],
169+
),
170+
(
171+
"optional_type_bound_two",
172+
Optional[_TestTypeVarTwoBound],
173+
[
174+
"tensor(int64)",
175+
"tensor(float)",
176+
"optional(tensor(int64))",
177+
"optional(tensor(float))",
178+
],
179+
),
180+
(
181+
"sequence_type_all",
182+
Sequence[onnxscript.onnx_types.TensorType],
183+
[
184+
f"sequence({tensor_type})"
185+
for tensor_type in type_annotation.ALL_TYPE_STRINGS
186+
],
187+
),
188+
("sequence_type", Sequence[INT64], ["sequence(tensor(int64))"]),
189+
(
190+
"union_sequence_type",
191+
Union[Sequence[INT64], Sequence[FLOAT]],
192+
["sequence(tensor(int64))", "sequence(tensor(float))"],
193+
),
194+
(
195+
"sequence_type_variadic_shape",
196+
Sequence[INT64[...]],
197+
["sequence(tensor(int64))"],
198+
),
199+
("sequence_type_shape", Sequence[INT64[10]], ["sequence(tensor(int64))"]),
200+
(
201+
"sequence_type_var_constraints",
202+
Sequence[_TestTypeVarConstraints],
203+
["sequence(tensor(int64))", "sequence(tensor(float))"],
204+
),
205+
(
206+
"sequence_type_bound_one",
207+
Sequence[_TestTypeVarOneBound],
208+
["sequence(tensor(int64))"],
209+
),
210+
(
211+
"sequence_type_bound_two",
212+
Sequence[_TestTypeVarTwoBound],
213+
["sequence(tensor(int64))", "sequence(tensor(float))"],
214+
),
215+
]
216+
)
217+
def test_pytype_to_input_strings(self, _, pytype: Any, expected):
218+
self.assertEqual(type_annotation.pytype_to_input_strings(pytype), expected)
219+
220+
@parameterized.parameterized.expand(
221+
[
222+
("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"),
223+
("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"),
224+
(
225+
"optional_type_var",
226+
Optional[_TestTypeVarOneBound],
227+
"Optional_TestTypeVarOneBound",
228+
),
229+
(
230+
"sequence_type_var",
231+
Sequence[_TestTypeVarOneBound],
232+
"Sequence_TestTypeVarOneBound",
233+
),
234+
("normal_type", INT64, "None"),
235+
("union_type", Union[INT64, FLOAT], None),
236+
("optional_type", Optional[INT64], None),
237+
("sequence_type", Sequence[INT64], None),
238+
("optional_sequence_type", Optional[Sequence[INT64]], None),
239+
("optional_union_type", Optional[Union[INT64, FLOAT]], None),
111240
]
112241
)
113-
def test_pytype_to_input_strings(self, _, pytype: Any, expected)
114-
pass
242+
def get_type_constraint_name(self, _: str, pytype, expected):
243+
self.assertEqual(type_annotation.get_type_constraint_name(pytype), expected)
115244

116-
def get_type_constraint_name(self, _: str, typevar, expected):
117-
pass
118245

119246
if __name__ == "__main__":
120247
unittest.main()

0 commit comments

Comments
 (0)