Skip to content

Commit 19f9484

Browse files
committed
Update on "Auto generate OpSchema for functions | feat(op_schema)"
This change adds the capability to auto generate `OpSchema`. ### Changes - Implement the `opschema` property in `OnnxFunction` - Test on all torch_lib functions ### Next PR Support trace_only functions ## Example ```python from onnxscript.function_libs.torch_aten.ops import core, nn print("core.aten_abs.opschema: ", core.aten_abs.opschema) print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema) ``` Results ``` core.aten_abs.opschema: OpSchema( name='aten_abs', domain='onnxscript.atenlib', since_version=1, doc='abs(Tensor self) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={} ) nn.aten_cross_entropy_loss.opschema: OpSchema( name='aten_cross_entropy_loss', domain='onnxscript.atenlib', since_version=1, doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor', type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')], inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)], attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>, description='', default_value=name: "ignore_index" i: -100 type: INT , required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>, description='', default_value=name: "label_smoothing" f: 0.0 type: FLOAT , required=False), 'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>, description='', default_value=name: "reduction" i: 1 type: INT , required=False), 'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>, description='', default_value=, required=True)} ) ``` Fixes #476 [ghstack-poisoned]
2 parents 28b4a48 + 622b688 commit 19f9484

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

onnxscript/type_annotation_test.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55

6+
from typing import TypeVar, Union
67
import unittest
78

9+
import parameterized
10+
811
import onnxscript.testing
12+
import onnxscript
913
from onnxscript import script
1014
from onnxscript.onnx_opset import opset15 as op
1115
from onnxscript.onnx_types import FLOAT
1216
from onnxscript.tests.common import testutils
1317

18+
from onnxscript import type_annotation
19+
1420

1521
class TypeAnnotationTest(testutils.TestBase):
1622
def test_type_annotation(self):
@@ -87,9 +93,27 @@ def bool_type_for_attribute(self: FLOAT[...], sorted: bool) -> FLOAT[...]:
8793
bool_type_for_attribute, bool_type_for_attribute_txt
8894
)
8995

96+
_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", onnxscript.INT64, onnxscript.FLOAT)
97+
_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=onnxscript.INT64)
98+
_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[onnxscript.INT64, onnxscript.FLOAT])
99+
90100

91101
class UtilityFunctionsTest(unittest.TestCase):
92-
def test_pytype_to_input_strings(self):
102+
@parameterized.parameterized.expand(
103+
[
104+
("tensor_type", onnxscript.onnx_types.TensorType, type_annotation.ALL_TYPE_STRINGS),
105+
("tensor_type", onnxscript.INT64, ["tensor(int64)"]),
106+
("tensor_type_variadic_shape", onnxscript.INT64[...], ["tensor(int64)"]),
107+
("tensor_type_shape", onnxscript.INT64[10], ["tensor(int64)"]),
108+
("type_var_constraints", _TestTypeVarConstraints, ["tensor(int64)"]),
109+
("type_bound_one", _TestTypeVarOneBound, ["tensor(int64)"]),
110+
("type_bound_two", _TestTypeVarTwoBound, ["tensor(int64)", "tensor(float)"]),
111+
]
112+
)
113+
def test_pytype_to_input_strings(self, _, pytype: Any, expected)
114+
pass
115+
116+
def get_type_constraint_name(self, _: str, typevar, expected):
93117
pass
94118

95119
if __name__ == "__main__":

0 commit comments

Comments
 (0)