Auto generate OpSchema for functions #594


Status: Closed (25 commits)

Conversation

justinchuby (Collaborator) commented Apr 5, 2023

This change adds the capability to auto generate OpSchema.

Changes

  • Implement the opschema property in OnnxFunction
  • Change attrs in IRFunction to self.attrs: list[IRAttributeValue] so that it contains type information
  • Include typeinfo in add_attr_parameter
  • Moved version_utils to _internal to use it in values
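From the caller's point of view, the first change above boils down to a lazily built, cached schema object. A minimal sketch of that shape (class and field names here are hypothetical stand-ins, not onnxscript's actual implementation):

```python
import functools

# Hedged sketch: an `opschema` property that builds the schema once on
# first access and caches it. The real OnnxFunction would construct an
# onnx.defs.OpSchema from the function IR; a plain dict stands in here.
class OnnxFunctionLike:
    def __init__(self, name: str):
        self.name = name

    @functools.cached_property
    def opschema(self) -> dict:
        return {
            "name": self.name,
            "domain": "onnxscript.atenlib",
            "since_version": 1,
        }

fn = OnnxFunctionLike("aten_abs")
print(fn.opschema["name"])  # aten_abs
print(fn.opschema is fn.opschema)  # True: built once, then cached
```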

TODO

Test on all torch_lib functions

Example

from onnxscript.function_libs.torch_aten.ops import core, nn


print("core.aten_abs.opschema: ", core.aten_abs.opschema)

print("nn.aten_cross_entropy_loss.opschema: ", nn.aten_cross_entropy_loss.opschema)

Results

core.aten_abs.opschema:  OpSchema(
    name='aten_abs',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='abs(Tensor self) -> Tensor',
    type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TReal', allowed_type_strs=['tensor(float)', 'tensor(int8)', 'tensor(int16)', 'tensor(int32)', 'tensor(int64)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    inputs=[OpSchema.FormalParameter(name='self', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[OpSchema.FormalParameter(name='return_val', type_str='TReal', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={}
)
nn.aten_cross_entropy_loss.opschema:  OpSchema(
    name='aten_cross_entropy_loss',
    domain='onnxscript.atenlib',
    since_version=1,
    doc='cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor',
    type_constraints=[OpSchema.TypeConstraintParam(type_param_str='TFloatOrBFloat16', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description=''), OpSchema.TypeConstraintParam(type_param_str='T1', allowed_type_strs=['tensor(float)', 'tensor(float16)', 'tensor(double)', 'tensor(bfloat16)'], description='')],
    inputs=[OpSchema.FormalParameter(name='self', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>), OpSchema.FormalParameter(name='weight', type_str='T1', description='', param_option=<FormalParameterOption.Optional: 1>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    outputs=[OpSchema.FormalParameter(name='result_10', type_str='TFloatOrBFloat16', description='', param_option=<FormalParameterOption.Single: 0>, is_homogeneous=True, min_arity=1, differentiation_category=<DifferentiationCategory.Unknown: 0>)],
    attributes={'ignore_index': OpSchema.Attribute(name='ignore_index', type=<AttrType.INT: 2>, description='', default_value=name: "ignore_index"
i: -100
type: INT
, required=False), 'label_smoothing': OpSchema.Attribute(name='label_smoothing', type=<AttrType.FLOAT: 1>, description='', default_value=name: "label_smoothing"
f: 0.0
type: FLOAT
, required=False), 'reduction': OpSchema.Attribute(name='reduction', type=<AttrType.INT: 2>, description='', default_value=name: "reduction"
i: 1
type: INT
, required=False), 'target': OpSchema.Attribute(name='target', type=<AttrType.INTS: 7>, description='', default_value=, required=True)}
)

Fixes #476

@justinchuby justinchuby added module: torchlib Related to the torch/aten function lib in development topic: api labels Apr 5, 2023
@justinchuby justinchuby marked this pull request as ready for review April 12, 2023 00:29

import onnx
import onnx.defs

from onnxscript import irbuilder, sourceinfo
from onnxscript import irbuilder, sourceinfo, type_annotation

Check notice (Code scanning / CodeQL): Cyclic import. Import of module onnxscript.irbuilder begins an import cycle.
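The usual fix for this kind of notice (and the pattern this PR already uses elsewhere, e.g. the deferred evaluator import) is to move one side of the cycle into a function body so it resolves at call time. A self-contained reproduction with throwaway module names, not the real onnxscript modules:

```python
import os
import sys
import tempfile
import textwrap

# mod_a imports mod_b at top level; mod_b only imports mod_a inside a
# function. The deferred import runs after both modules have finished
# initializing, so the cycle never bites at load time.
tmp = tempfile.mkdtemp()
with open(os.path.join(tmp, "mod_a.py"), "w") as f:
    f.write(textwrap.dedent("""\
        import mod_b  # safe: mod_b has no top-level import of mod_a

        VALUE = "from mod_a"

        def call_b():
            return mod_b.use_a()
    """))
with open(os.path.join(tmp, "mod_b.py"), "w") as f:
    f.write(textwrap.dedent("""\
        def use_a():
            import mod_a  # deferred: resolved at call time, not import time
            return mod_a.VALUE
    """))
sys.path.insert(0, tmp)
import mod_a

print(mod_a.call_b())  # from mod_a
```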
@justinchuby justinchuby requested review from BowenBao and jcwchen April 12, 2023 00:56
justinchuby (Collaborator, Author): @gramalingam I also see AttrRef; how should it be used?

codecov bot commented Apr 12, 2023

Codecov Report

Merging #594 (7207107) into main (f1591c7) will decrease coverage by 0.08%.
The diff coverage is 89.04%.

@@            Coverage Diff             @@
##             main     #594      +/-   ##
==========================================
- Coverage   74.20%   74.13%   -0.08%     
==========================================
  Files         107      107              
  Lines       11302    11401      +99     
  Branches     1177     1197      +20     
==========================================
+ Hits         8387     8452      +65     
- Misses       2592     2619      +27     
- Partials      323      330       +7     
Impacted Files Coverage Δ
onnxscript/_internal/version_utils.py 100.00% <ø> (ø)
onnxscript/converter.py 81.74% <ø> (-1.20%) ⬇️
onnxscript/type_annotation.py 67.00% <55.17%> (-5.98%) ⬇️
onnxscript/irbuilder.py 77.12% <83.33%> (+0.17%) ⬆️
onnxscript/onnx_types.py 93.57% <98.07%> (+1.26%) ⬆️
onnxscript/autocast.py 92.85% <100.00%> (ø)
...pt/function_libs/torch_aten/graph_building_test.py 89.09% <100.00%> (ø)
...s/function_libs/torch_aten/ops_correctness_test.py 86.35% <100.00%> (-2.47%) ⬇️
onnxscript/tests/functions/onnxfns1A_test.py 85.45% <100.00%> (+0.26%) ⬆️
onnxscript/tests/functions/onnxfns2_test.py 94.73% <100.00%> (+0.14%) ⬆️
... and 2 more

... and 1 file with indirect coverage changes


@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:

# Currently, only tensor types are supported. Need to expand support for other ONNX types.
ONNXType = TensorType

ALL_TENSOR_TYPES = (
Contributor: Could you please sort them in alphabetical order for better readability?

Collaborator (Author): Done
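The alphabetization asked for above is a one-liner if the tuple is built from a sorted sequence. The element names below are illustrative only; the concrete entries in the PR's ALL_TENSOR_TYPES differ:

```python
# Illustrative: keep a type tuple alphabetized by constructing it sorted,
# so review diffs stay small and lookups are easy to eyeball.
UNSORTED = ("FLOAT", "INT8", "BFLOAT16", "DOUBLE", "BOOL")
ALL_TENSOR_TYPES = tuple(sorted(UNSORTED))
print(ALL_TENSOR_TYPES)  # ('BFLOAT16', 'BOOL', 'DOUBLE', 'FLOAT', 'INT8')
```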

@@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["numpy", "onnx", "typing_extensions"]
dependencies = ["numpy", "onnx>=1.13", "typing_extensions"]
Contributor: Why do we specify the minimum onnx version here?

@@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
]
dependencies = ["numpy", "onnx", "typing_extensions"]
dependencies = ["numpy", "onnx>=1.13", "typing_extensions"]
Contributor: What is this for? I see OpSchema is no older than onnx 1.14?

Collaborator (Author): I am planning to bump this when onnx 1.14 is released. For now I can revert, I guess.

self._param_schemas: Optional[tuple[ParamSchema, ...]] = None

def __call__(self, *args, **kwargs):
# FIXME(after #225): Move import to the top of the file.
from onnxscript import evaluator # pylint: disable=import-outside-toplevel

return evaluator.default().eval(self.get_schema(), args, kwargs)
schema = self.get_schema()
assert schema is not None
Contributor: Could this be None for an Op? What is the case here?

Collaborator (Author): When the op doesn't have a schema, this will fail, I think. That would be the custom ops. Here we expect that the schema is always not None.

Contributor: Oh, I forgot custom ops. How does a custom op get in here? I thought onnx-script didn't support non-ONNX ops. If it's a custom op built in a non-onnx-script way (not composed of standard ONNX ops), how does onnx-script process it? I thought we could only handle it on the exporter side.
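One defensive alternative to the bare `assert schema is not None` under discussion is to raise a descriptive error when a schema is missing. This is a sketch under assumed names, not onnxscript's actual Op API:

```python
from typing import Any, Optional

class Op:
    """Hypothetical op wrapper: evaluation requires a schema."""

    def __init__(self, name: str, schema: Optional[object] = None):
        self.name = name
        self._schema = schema

    def get_schema(self) -> Optional[object]:
        return self._schema

    def __call__(self, *args: Any, **kwargs: Any):
        schema = self.get_schema()
        if schema is None:
            # A custom op is not composed of standard ONNX ops, so there is
            # no schema to evaluate against; defer to the exporter instead.
            raise NotImplementedError(
                f"op {self.name!r} has no OpSchema; custom ops must be "
                "handled by the exporter"
            )
        return (self.name, args, kwargs)  # stand-in for evaluator.eval(...)

print(Op("Abs", schema=object())(1.0))  # ('Abs', (1.0,), {})
```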

fn: IRFunction,
varname: str,
attribute_type: onnx.AttributeProto.AttributeType,
default_value: Any,
Contributor: Is there a better annotation we can use than Any?

Collaborator (Author): I think so. Let me refine that.

proto = onnx.AttributeProto()
proto.name = varname
proto.type = attribute_type
fn.add_attr_parameter(IRAttributeValue(proto), has_default=False)
Contributor: Do you think has_default can be an attribute of IRAttributeValue? I think that's more straightforward.

Collaborator (Author): That's a great suggestion; I thought of this too. Will make an update.

@@ -203,3 +239,22 @@ def onnx_type_to_onnxscript_repr(onnx_type: onnx.TypeProto) -> str:

# Currently, only tensor types are supported. Need to expand support for other ONNX types.
ONNXType = TensorType
Contributor: Do we still need this?

Collaborator (Author): Not sure. Maybe?

@justinchuby justinchuby marked this pull request as draft April 12, 2023 17:33
justinchuby (Collaborator, Author): Closed in favor of #626

Labels: module: torchlib (Related to the torch/aten function lib in development), topic: api
Successfully merging this pull request may close these issues: OpSchema in Python.
3 participants