Fix ops type/name and adapt API to support torch dispatcher/registry | feat(torchlib) #721
Conversation
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map an ATen op to its OnnxFunction. However, the issue arises when an ATen op or an OnnxFunction has overloads, which cannot be resolved by a one-to-one mapping. For example, `aten::arange` has overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`. This PR utilizes the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator and find the correct overload.

### OnnxRegistry

Heavily references the [TorchScript Registry](#84382). The only difference is that in the FX registry, an ATen operator at a specific opset version is mapped to a list of overloaded functions.

* No longer uses a global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator API (`register`, `unregister`, and `is_register_op`) is kept for future development. To complete it, follow-up PRs should address:
  - How to allow users to remove/override a specific overload. Use OpSchema to differentiate?
  - What happens when a user registers a new overload with the same OpSchema as an already registered overload.

Co-authored-by: Justin Chu <justinchuby@microsoft.com>

[ghstack-poisoned]
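The registry described above keys on an ATen operator plus an opset version and stores a *list* of overloaded functions rather than a single symbolic function. A minimal sketch of that shape, with hypothetical class and method names that do not mirror the actual torch.onnx implementation:

```python
# Illustrative sketch only: an FX-style registry where one (ATen op, opset
# version) key maps to a list of candidate overloads. Class/method names here
# are hypothetical and do not mirror the real implementation in this PR.
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Tuple


class _SketchOnnxRegistry:
    def __init__(self, default_opset: int = 18) -> None:
        self._default_opset = default_opset
        # (qualified ATen name, opset version) -> list of overloaded functions
        self._registry: Dict[Tuple[str, int], List[Callable]] = defaultdict(list)

    def register(self, aten_name: str, fn: Callable, opset: Optional[int] = None) -> None:
        self._registry[(aten_name, opset or self._default_opset)].append(fn)

    def unregister(self, aten_name: str, fn: Callable, opset: Optional[int] = None) -> None:
        overloads = self._registry.get((aten_name, opset or self._default_opset), [])
        if fn in overloads:
            overloads.remove(fn)

    def is_registered_op(self, aten_name: str, opset: Optional[int] = None) -> bool:
        return bool(self._registry.get((aten_name, opset or self._default_opset)))

    def get_overloads(self, aten_name: str, opset: Optional[int] = None) -> List[Callable]:
        # The dispatcher later picks the best overload from this list by
        # comparing OpSchema against the actual input arguments.
        return list(self._registry.get((aten_name, opset or self._default_opset), []))
```

With such a shape, a dispatcher would call something like `get_overloads("aten::argmax", 18)` and then choose between `aten_argmax` and `aten_argmax_dim` based on the concrete arguments.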
@@ -130,13 +134,21 @@ def shape(self, shape: Tuple[int | None, ...]):
@property
def dtype(self):
nit: is it possible to annotate return type?
Done
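For reference, a rough sketch of what the annotated property could look like; the wrapper class is a placeholder and only the explicit return annotation is the point, so the actual code in the PR may differ:

```python
# Rough sketch only: a @property with an explicit return annotation, mirroring
# the snippet under review. The wrapper class is a placeholder, not the real
# class in the PR.
from torch.onnx import _type_utils


class _ValueWrapper:
    def __init__(self, value):
        self._value = value

    @property
    def dtype(self) -> _type_utils.JitScalarType:
        # TODO: Return numpy dtype
        return _type_utils.JitScalarType.from_value(  # type: ignore[attr-defined]
            self._value
        )
```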
@@ -49,6 +49,9 @@ def __iter__(self):
def __repr__(self):
    return repr(self._registry)

def items(self):
nit: return type annotation?
@@ -130,13 +134,21 @@ def shape(self, shape: Tuple[int | None, ...]):
@property
def dtype(self):
    # TODO: Return numpy dtype
-   return _type_utils.JitScalarType.from_value(  # type: ignore[attr-defined]
+   torch_dtype = _type_utils.JitScalarType.from_value(  # type: ignore[attr-defined]
🤔 Is it possible to extend `_type_utils.JitScalarType.from_value` to handle sequence type properly?
I see, you don't have a `torch.dtype` for sequence to be passed in from `dtype.setter` to begin with.
Frankly, it feels a bit overloaded and doesn't seem to match the design of this class to also cover the sequence type.
Should we add another attribute?
An alternative I can think of is to give up this dtype and relax `match_schema` so that it finds the best match instead of "the 100% match", as I found some inputs are just None anyway.
Reverted
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map an ATen op to its OnnxFunction. However, the issue arises when an ATen op or an OnnxFunction has overloads, which cannot be resolved by a one-to-one mapping. For example, `aten::arange` has overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`. This PR utilizes the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator and find the correct overload.

### OnnxRegistry

Heavily references the [TorchScript Registry](#84382). The only difference is that in the FX registry, an ATen operator at a specific opset version is mapped to a list of overloaded functions.

* No longer uses a global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator API (`register`, `unregister`, and `is_register_op`) is kept for future development. To complete it, follow-up PRs should address:
  - How to allow users to remove/override a specific overload. Use OpSchema to differentiate?
  - What happens when a user registers a new overload with the same OpSchema as an already registered overload.

### OnnxDispatcher

Dispatches ATen operators to the matched overload by comparing OpSchema with the input arguments.

* `OpSchemaWrapper` wraps the ONNX schema and records the matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types and find the best-matched overload. If the match isn't perfect, a warning is recorded in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset versions other than 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT may not follow the ONNX type spec, we relax the schema match with a matching score system.
* To extend support, follow-up PRs should address:
  - How to add `op.Cast` with autocast? In torchlib or in the converter?
  - The need for type promotion can be captured by the dispatcher, but OpSchema needs to expose the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system

This is a temporary solution for targeting the correct ONNX overload, given that we only have manually annotated arguments (a potentially inaccurate schema) and limited support for AttributeProto.

1. Perfect match exam: if all arguments/kwargs are matched, return the function without any warnings.
2. Best match exam: the system counts each correctly matching input in order and subtracts the symmetric difference between the attributes to calculate the matching score, then selects the overload with the highest score. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one.

```python
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal: ...

@torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ...
```

2. Optional dim: caused by unsupported `op.OptionalHasElement` (will be supported in opset version 20). `dim` could be `None`.

```python
@torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8: ...

@torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ...
```

This case is impossible to differentiate, as both might have `dim` in kwargs, so please make sure the one with `dim: int` is turned into a private function.

3. Optional dtype: `dtype` could be "unprovided". The difference from case 2 is that `dtype` would not be `None`.

```python
@torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ...

@torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ...
```

Depending on whether `dtype` is provided, the matching system dispatches the ATen op to the correct one.

4. `None`, `[]`, and `NoneType` are considered to fail the match.
5. Two functions having the same score is recorded in SARIF.

### TODOs

1. Type promotion can be captured by the dispatcher only if OpSchema can provide it. However, the implementation of a "graph-level" pass vs. "in-op" promotion can be further discussed in microsoft/onnxscript#563.
2. torchlib should provide the "opset version" to OnnxRegistry.
3. How to expose OnnxRegistry with custom add/remove-op APIs needs to be further discussed.

Co-authored-by: Justin Chu <justinchuby@microsoft.com>

[ghstack-poisoned]
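As a rough illustration of the matching-score mechanism above (not the actual `OpSchemaWrapper` or dispatcher code), the score could be computed by counting the inputs whose dtype is allowed by an overload's schema, subtracting the symmetric difference between the provided and declared attribute names, and dispatching to the highest-scoring overload:

```python
# Hypothetical sketch of the matching-score idea described above; names and
# data shapes are illustrative and do not reproduce the real OpSchemaWrapper
# or dispatcher code.
from typing import AbstractSet, Sequence


def matching_score(
    input_dtypes: Sequence[str],                 # e.g. ["tensor(float)", "tensor(int64)"]
    provided_attrs: AbstractSet[str],            # attribute names supplied by the FX node
    allowed_dtypes: Sequence[AbstractSet[str]],  # allowed dtypes per formal input
    schema_attrs: AbstractSet[str],              # attribute names declared by the schema
) -> int:
    # Count inputs whose dtype is allowed by this overload's schema.
    score = sum(
        1 for dtype, allowed in zip(input_dtypes, allowed_dtypes) if dtype in allowed
    )
    # Attributes present on only one side lower the score (symmetric difference).
    score -= len(set(provided_attrs) ^ set(schema_attrs))
    return score


def pick_overload(overloads, input_dtypes, provided_attrs):
    # Choose the highest-scoring overload; in the real dispatcher a non-perfect
    # best match would additionally be reported as a diagnostic/SARIF warning.
    # Each item in `overloads` is assumed to expose `allowed_dtypes` and `attrs`.
    return max(
        overloads,
        key=lambda ov: matching_score(
            input_dtypes, provided_attrs, ov.allowed_dtypes, ov.attrs
        ),
    )
```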
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `OpSchemaWrapper` wrap the onnx schema, and record matching score. * `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information. ### OpSchemaWrapper - Matching Score Mechanism #### The matching score system: This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto. 1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings. 2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF. #### Example of overloads 1. Different types: Caused by the difference between the ONNX spec and PyTorch. The matching system finds the correct one. ```python torch_op("aten::mul") def aten_mul(self: TReal, other: TReal) -> TReal: ... torch_op("aten::mul") def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL: ... ``` 2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). 
dim could be "None" ```python torch_op("aten::argmax", trace_only=True) def aten_argmax( self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False ) -> TrealOrUInt8: ... torch_op("aten::argmax", private=True) def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8: ... ``` This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function. 3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None. ```python torch_op("aten::new_full") def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor: ... torch_op("aten::new_full") def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor: ... ``` Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one. 4. `None` and `[]` and `NoneType` are considered failing the match. 5. Two functions have the same score is recorded into SARIFs. ### TODOs 1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563. 5. torchlib should provide the "opset version" to OnnxRegistry. 7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map each ATen op to its OnnxFunction. This one-to-one mapping breaks down when an ATen op has overloads or when torchlib provides more than one OnnxFunction for the same op. For example, `aten::arange` has the overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`. This PR uses the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator against the registered overloads and find the correct one.

### OnnxRegistry

Heavily references the [TorchScript Registry](#84382). The only difference is that in the FX registry, an ATen operator at a specific opset version is mapped to a list of overloaded functions.

* No longer uses a global registry. The registry is initialized in `ResolvedExportOptions` with torchlib and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator API (`register`, `unregister`, and `is_register_op`) is kept for future development. To complete it, follow-up PRs should address:
    - How to allow users to remove/override a specific overload? Using OpSchema to differentiate?
    - What happens when a user registers a new overload with the same OpSchema as an already registered overload?

### OnnxDispatcher

Dispatches ATen operators to the matched overload by comparing OpSchema with the input arguments.

* `OpSchemaWrapper` wraps the ONNX schema and records the matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types and find the best-matched overload. If the match isn't perfect, a warning is recorded in diagnostics.
* `dispatch_opset_version` is carried over from #84382, but torchlib does not support opset versions other than 18.
* Because (1) OnnxFunction arguments are manually typed and (2) ORT may deviate from the ONNX type spec, we relax the schema match with a matching score system.
* To extend support, follow-up PRs should address:
    - How to add `op.Cast` with autocast? In torchlib or in the converter?
    - The need for type promotion can be captured by the dispatcher, but OpSchema needs to expose the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system

This is a temporary solution for targeting the correct ONNX overload, given that we only have manually annotated arguments (a potentially inaccurate schema) and limited support for AttributeProto.

1. Perfect match exam: if all arguments/kwargs match, return the function without any warnings.
2. Best match exam: the system adds up the correctly matched inputs in order and subtracts the symmetric difference between the attributes to compute the matching score, then selects the candidate with the highest score. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Examples of overloads

1. Different types: caused by differences between the ONNX spec and PyTorch. The matching system finds the correct one.

```python
@torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

@torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by the unsupported op.OptionalHasElement (will be supported in opset 20). dim could be "None".

```python
@torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

@torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as both might have dim in kwargs, so please make sure the one with `dim: int` is turned into a private function.

3. Optional dtype: dtype could be "unprovided". The difference from case 2 is that dtype would not be None.

```python
@torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

@torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depending on whether dtype is provided, the matching system dispatches the ATen op to the correct overload.

4. `None`, `[]`, and `NoneType` are considered as failing the match.
5. When two functions have the same score, it is recorded into SARIF.

### TODOs

1. Type promotion can be captured by the dispatcher only if OpSchema can provide it. However, "graph-level" pass vs. "in-op" promotion can be further discussed in microsoft/onnxscript#563.
2. torchlib should provide the "opset version" to OnnxRegistry.
3. How to expose OnnxRegistry with custom add/remove op APIs needs further discussion.

Co-authored-by: Justin Chu <[email protected]>
Pull Request resolved: #100660
Approved by: https://github.com/thiagocrepaldi
Improve logic and fix bugs in Registry to provide a complete torchlib to the converter. To enable the dispatcher in the converter: pytorch/pytorch#100660
To enable the OpSchema-based dispatcher in the torch converter, I am annotating all input arguments from the FX graph except seq(tensor), which requires us to handle the non-annotated types (Sym values and ONNX Sequence):
https://github.com/pytorch/pytorch/blob/9a811d1df23acd8285e72f938c3919a4993bca69/torch/onnx/_internal/fx/passes/fx_to_onnxscript.py#L197-L208
NOTE:
For Sym values, since we treat them as normal TorchScriptTensors in `aten_add`, `aten_sub`, etc., they are annotated as int64 (SymInt) and float32 (SymFloat).
For ONNX Sequence, these are not defined in torch or TorchScript. We return NoneType, and the dispatcher treats NoneType/None/empty list as AllType to relax the match; see the sketch below.
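As an illustration of the note above, here is a minimal sketch of the annotation rule, assuming made-up helper and sentinel names (the real pass lives in the `fx_to_onnxscript.py` code linked above and differs in detail):

```python
# Hypothetical illustration of the dtype rule described in the NOTE; the helper
# name and the ALL_TYPES sentinel are made up for this sketch.
import torch

ALL_TYPES = None  # NoneType/None/[] relax the match to accept any overload input type.


def fx_value_dtype_for_matching(value):
    """Map an FX value to the dtype the dispatcher would compare against OpSchema."""
    if isinstance(value, torch.SymInt):
        return torch.int64    # Sym ints are treated like INT64 tensors.
    if isinstance(value, torch.SymFloat):
        return torch.float32  # Sym floats are treated like FLOAT tensors.
    if isinstance(value, (list, tuple)):
        return ALL_TYPES      # ONNX Sequence has no torch dtype; relax the match.
    if isinstance(value, torch.Tensor):
        return value.dtype
    return ALL_TYPES
```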