PyTorch FX exporter dispatching and type promotion scenarios #563


Closed
2 of 4 tasks
justinchuby opened this issue Mar 27, 2023 · 13 comments
Labels: module: torchlib (Related to the torch/aten function lib in development), topic: discussion (For discussion)

Comments

@justinchuby
Collaborator

justinchuby commented Mar 27, 2023

  1. An ATen op can be dispatched to multiple onnx-script implementations:

aten::scaled_dot_product_attention ->

def aten_scaled_dot_product_attention(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[TFloat] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
): ...

def aten_scaled_dot_product_attention_bool_mask(
    query: TFloat,
    key: TFloat,
    value: TFloat,
    attn_mask: Optional[BOOL] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
): ...

The exporter will need to dispatch to one of these implementations based on the dtype of attn_mask.
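A minimal sketch of such dtype-based dispatch, assuming the exporter can read the ONNX dtype of attn_mask; the helper below is hypothetical, not the exporter's actual API:

```python
from typing import Optional

import onnx

def choose_sdpa_overload(attn_mask_dtype: Optional[int]):
    """Pick between the two implementations above for aten::scaled_dot_product_attention.

    attn_mask_dtype is an onnx.TensorProto.DataType value, or None when no mask is given.
    """
    if attn_mask_dtype == onnx.TensorProto.BOOL:
        return aten_scaled_dot_product_attention_bool_mask
    # No mask, or a floating-point mask: use the default implementation.
    return aten_scaled_dot_product_attention
```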

  2. (Easy) Multiple ATen op overloads may map to a single onnx-script implementation:

Both `add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor` and `add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor` map to a single

def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal: ...
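For illustration only (the registry layout and the make_constant helper below are hypothetical, not the exporter's actual code), both overloads could resolve to the same entry, with the Scalar variant normalized to a tensor before the call:

```python
ATEN_TO_ONNX = {
    "aten::add.Tensor": aten_add,
    "aten::add.Scalar": aten_add,  # both overloads share one implementation
}

def export_add(overload_name, self_tensor, other, alpha=1.0):
    fn = ATEN_TO_ONNX[overload_name]
    if overload_name == "aten::add.Scalar":
        # Hypothetical helper: wrap the Python Scalar as a constant tensor.
        other = make_constant(other)
    return fn(self_tensor, other, alpha=alpha)
```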

Type promotion

PyTorch

  • torch_implicit_integral_promotion: Sometimes we don't need to Cast the output back, e.g. Cast(int to float) -> Cos -> output (see pytorch/pytorch#100265).
  • torch_implicit_type_promotion: ONNX does not have automatic type promotion.
    • We can either run an fx pass to create explicit type promotion nodes before converting to ONNX, or
    • Create rules to promote types in the exporter during conversion.
    • The exporter can consult the type_constraints in a function's OpSchema to determine whether it needs to cast inputs to the same type.
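A rough sketch of the second option (rule-based promotion during conversion), assuming a hypothetical graph.add_cast helper; torch.promote_types is the real PyTorch API for computing the promoted dtype:

```python
import torch

def promote_inputs(graph, inputs, torch_dtypes):
    """Cast all inputs to a common dtype, mirroring torch's implicit type promotion."""
    promoted = torch_dtypes[0]
    for dt in torch_dtypes[1:]:
        promoted = torch.promote_types(promoted, dt)
    return [
        inp if dt == promoted else graph.add_cast(inp, promoted)  # hypothetical exporter API
        for inp, dt in zip(inputs, torch_dtypes)
    ]
```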

ONNX

  • ONNX tends to support fewer types than PyTorch. When an ONNX function is selected as the dispatch target but does not support the input type, the exporter needs to insert Cast nodes before and after the function, e.g. Cast -> <the torchlib function> -> Cast, to mirror the ATen op's behavior.
  • The exporter may not always find a function with an exact type match, but it can still make a successful dispatch using the casting process above. For example, we may have two functions supporting (a) INT64 and (b) FLOAT32, and the input is BFLOAT16. The dispatcher should be able to choose (b) as the target and insert the proper Cast nodes as described above.
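A hedged sketch of that Cast -> <torchlib function> -> Cast wrapping, assuming an eager/tracing context where the opset-18 ops can be called directly; the wrapper itself is illustrative, not an existing exporter helper:

```python
from onnx import TensorProto
from onnxscript import opset18 as op

def call_with_cast(fn, x, actual_dtype=TensorProto.BFLOAT16, supported_dtype=TensorProto.FLOAT):
    x_cast = op.Cast(x, to=supported_dtype)  # cast the input up to a type the function supports
    y = fn(x_cast)                           # run the dispatched torchlib function
    return op.Cast(y, to=actual_dtype)       # cast the result back to mirror the ATen behavior
```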

cc @titaiwangms @thiagocrepaldi @BowenBao

@justinchuby justinchuby added the topic: discussion For discussion label Mar 27, 2023
@justinchuby justinchuby changed the title PyTorch FX exporter dispatcher scenarios PyTorch FX exporter dispatching scenarios Mar 27, 2023
@justinchuby

This comment was marked as duplicate.

@justinchuby justinchuby changed the title PyTorch FX exporter dispatching scenarios PyTorch FX exporter dispatching and type promotion scenarios Apr 25, 2023
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Apr 25, 2023
@BowenBao
Contributor

From the perspective that each ATen op defines its own type promotion logic, and also from the perspective of eager execution: could you discuss the pros/cons/viability of having the dispatcher/evaluator/OnnxFunction handle automatic type promotion?

titaiwangms added a commit to pytorch/pytorch that referenced this issue May 11, 2023
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map an ATen op to its OnnxFunction. However, an issue arises when the ATen op or the OnnxFunction has overloads, which is not resolvable by a one-to-one mapping. For example, `aten::arange` has the overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`.

This PR utilizes the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator and find the correct overload.

### OnnxRegistry

Heavily references the [TorchScript Registry](#84382). The only difference is that in the FX registry, an ATen operator with a specific opset version is mapped to a list of overloaded functions.

* No longer uses a global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator APIs `register`, `unregister`, and `is_register_op` are kept for future development. To complete them, follow-up PRs should address:
    - How to allow users to remove/override a specific overload? Use OpSchema to differentiate?
    - What happens when a user registers a new overload with the same OpSchema as an already registered overload.
 
### OnnxDispatcher

Dispatches ATen operators to the matched overload by comparing the OpSchema with the input arguments.

* `OpSchemaWrapper` wraps the ONNX schema and records the matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types and find the best-matched overload. If the match isn't perfect, a warning is recorded in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset versions other than 18.
* To extend support, follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or the converter?
    - The need for type promotion can be captured by the dispatcher, but this requires the OpSchema to expose the T1/T2 information.

### TODOs

1. Type promotion can be captured by the dispatcher only if the OpSchema can provide it. However, the choice between a "graph-level" pass and "in-op" promotion can be further discussed in microsoft/onnxscript#563.
2. torchlib should provide the "opset version" to OnnxRegistry.
3. How to expose OnnxRegistry with custom add/remove op APIs needs to be further discussed.

Co-authored-by: Justin Chu <justinchuby@microsoft.com>

[ghstack-poisoned]
@BowenBao
Copy link
Contributor

Discussed offline with @titaiwangms.

Scratch notes for case of torchlib owning type promotion

  • Identify and maintain a set of type_promotion_kinds. Each type_promotion_kind represents a set of rules for how casting should be done based on the arg/output types.
  • Decorate torchlib ops with their respective type_promotion_kind, which gets populated as an attribute of the OnnxFunction and OnnxTraceFunction instances.
  • Extend Evaluator (or a TypePromotionEvaluator) to insert the proper Cast nodes based on the opschema, the actual arg/output types, and the type_promotion_kinds.
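A scratch sketch of that decoration idea, assuming hypothetical names (with_type_promotion, TypePromotionKind) rather than anything that exists in torchlib today:

```python
import enum

class TypePromotionKind(enum.Enum):
    DEFAULT = "default"            # standard binary-op promotion (e.g. aten::add)
    INT_TO_FLOAT = "int_to_float"  # ops that promote integer inputs to float (e.g. aten::cos)

def with_type_promotion(kind: TypePromotionKind):
    """Attach the promotion kind as an attribute so an evaluator can consult it."""
    def wrap(fn):
        fn.type_promotion_kind = kind
        return fn
    return wrap

@with_type_promotion(TypePromotionKind.DEFAULT)
def aten_mul_example(self, other): ...
```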

titaiwangms added a commit to pytorch/pytorch that referenced this issue May 15, 2023
Needs microsoft/onnxscript#721

The current FX exporter uses a manually maintained dictionary to map an ATen op to its OnnxFunction. However, an issue arises when the ATen op or the OnnxFunction has overloads, which is not resolvable by a one-to-one mapping. For example, `aten::arange` has the overloads `aten::arange.start` and `aten::arange.start_step`, and for `aten::argmax`, torchlib provides two functions: `aten_argmax` and `aten_argmax_dim`.

This PR utilizes the newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator and find the correct overload.

### OnnxRegistry

Heavily references the [TorchScript Registry](#84382). The only difference is that in the FX registry, an ATen operator with a specific opset version is mapped to a list of overloaded functions.

* No longer uses a global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* The multiple-opset-version layer is kept through `_SymbolicFunctionGroup`, but torchlib currently only supports opset 18.
* The basic custom-operator APIs `register`, `unregister`, and `is_register_op` are kept for future development. To complete them, follow-up PRs should address:
    - How to allow users to remove/override a specific overload? Use OpSchema to differentiate?
    - What happens when a user registers a new overload with the same OpSchema as an already registered overload.
 
### OnnxDispatcher

Dispatches ATen operators to the matched overload by comparing the OpSchema with the input arguments.

* `OpSchemaWrapper` wraps the ONNX schema and records the matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types and find the best-matched overload. If the match isn't perfect, a warning is recorded in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset versions other than 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT may not follow the ONNX type spec, we relax the schema match with a `matching score system`.
* To extend support, follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or the converter?
    - The need for type promotion can be captured by the dispatcher, but this requires the OpSchema to expose the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system
This is a temporary solution for targeting the correct ONNX overload, given that we only have manually annotated arguments (a potentially inaccurate schema) and limited support for AttributeProto.

1. Perfect match check: if all args/kwargs match, return the function without any warnings.
2. Best match check: the system counts each correctly matching input in order and subtracts the symmetric difference between the attribute sets to calculate the matching score, then selects the candidate with the highest score. If the selection is not a perfect match, a warning message is sent to SARIF.
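A rough, hypothetical illustration of that scoring idea (not the PR's actual implementation): count the inputs whose dtype is allowed by the overload, then subtract the symmetric difference between the expected and provided attribute names.

```python
def matching_score(allowed_input_types, input_dtypes, expected_attrs, provided_attrs):
    # One point per input whose dtype is in the overload's allowed set.
    score = sum(
        1 for dtype, allowed in zip(input_dtypes, allowed_input_types) if dtype in allowed
    )
    # Penalize attribute mismatches in either direction.
    score -= len(set(expected_attrs) ^ set(provided_attrs))
    return score
```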

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by the unsupported op.OptionalHasElement (to be supported in opset 20). dim could be None.

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as both might have dim in kwargs, so please make sure the one with `dim: int` is turned into a private function.

3. Optional dtype: dtype could be unprovided. The difference from case 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depending on whether dtype is provided, the matching system dispatches the ATen op to the correct overload.

4. `None`, `[]`, and `NoneType` are considered to fail the match.

5. When two functions have the same score, this is recorded in SARIF.

### TODOs

1. Type promotion can be captured by the dispatcher only if the OpSchema can provide it. However, the choice between a "graph-level" pass and "in-op" promotion can be further discussed in microsoft/onnxscript#563.
2. torchlib should provide the "opset version" to OnnxRegistry.
3. How to expose OnnxRegistry with custom add/remove op APIs needs to be further discussed.

Co-authored-by: Justin Chu <justinchuby@microsoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 15, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 16, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 16, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
Needs microsoft/onnxscript#721

The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim.

This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload.

### OnnxRegistry

Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions.

* No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future.
* Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18.
* Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address:
    - How to allow users to remove/override specific overload? Using OpSchema to differentiate?
    - User registers a new overload with the same OpSchema as one of registered overload. 
 
### OnnxDispatcher

Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments.

* `OpSchemaWrapper` wrap the onnx schema, and record matching score.
* `dispatch` uses `OpSchemaWrapper` to compare data types to find the best matched overload. If the match isn't perfect, record warning in diagnostics.
* `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18.
* Because right now (1) OnnxFunction arguments are manually typed, and (2) ORT could unfollow ONNX type spec, we relax the schema match with `matching score system`.
* To include more supports:  the follow-up PRs should address:
    - How to add op.Cast with autocast? In torchlib or converter?
    - The need of type promotion can be captured by dispatcher, but needs OpSchema shows the T1/T2 information.

### OpSchemaWrapper - Matching Score Mechanism

#### The matching score system:
This is a temporary solution to how we target the correct ONNX overloads given that we only have manually annotated arguments (potentially inaccurate schema) and limited supports on AttributeProto.

1. Perfect match exam: If all arguments/kwargs are all matched, return the function without any warnings.
2. Best match exam: The system add the each correct matching input counts orderly, and subtract the symmetrical difference between their attributes to calculate the matching score. And select the one with the highest score in the end. If the selection is not a perfect match, a warning message is sent to SARIF.

#### Example of overloads

1. Different types: Caused by the difference between the ONNX spec and PyTorch.

The matching system finds the correct one.

```python
torch_op("aten::mul")
def aten_mul(self: TReal, other: TReal) -> TReal:
    ...

torch_op("aten::mul")
def aten_mul_bool(self: BOOL, other: BOOL) -> BOOL:
    ...
```

2. Optional dim: caused by unsupported op.OptionalHasElement (will support on opset version == 20). dim could be "None"

```python
torch_op("aten::argmax", trace_only=True)
def aten_argmax(
    self: TrealOrUInt8, dim: Optional[int] = None, keepdim: bool = False
) -> TrealOrUInt8:
    ...

torch_op("aten::argmax", private=True)
def _aten_argmax_dim(self: TrealOrUInt8, dim: int, keepdim: bool = False) -> TrealOrUInt8:
    ...
```

This case is impossible to differentiate, as they both might have dim in kwargs, so in this case, please make sure you turn the one with `dim: int` to private function.

3. Optional dtype: dtype could be "unprovided". The difference from 2 is that dtype would not be None.

```python
torch_op("aten::new_full")
def aten_new_full(self: TTensor, size: INT64, fill_value: TTensor) -> TTensor:
    ...

torch_op("aten::new_full")
def aten_new_full_dtype(self: TTensor, size: INT64, fill_value: TTensor, dtype: int) -> TTensor:
    ...
```

Depends on dtype is provided or not, matching system will dispatch the ATen op to the correct one.

4. `None` and `[]` and `NoneType` are considered failing the match. 

5. Two functions have the same score is recorded into SARIFs. 

### TODOs

1. Type promotion can be captured by dispatcher only if OpSchema can provide it. However, the implementation of "graph-level" pass vs "in-op"" promotion can be further discussed in microsoft/onnxscript#563.
5. torchlib should provide the "opset version" to OnnxRegistry.
7. How to expose OnnxRegistry with custom add/remove ops APIs nneds to be further discussed.

Co-authored-by: Justin Chu <justinchubymicrosoft.com>

[ghstack-poisoned]
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 17, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 18, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 19, 2023
titaiwangms added a commit to pytorch/pytorch that referenced this issue May 19, 2023
justinchuby added a commit that referenced this issue Jun 23, 2023
-----

- Mark nan==nan in assert_close
- Relaxed float16 precision in comparison
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 23, 2023
justinchuby added a commit that referenced this issue Jun 23, 2023
@justinchuby
Collaborator Author

Type comparison: https://github.com/pytorch/pytorch/blob/23b7035b3c61b0900e900e39ba7e69997a417ea8/torch/_prims_common/__init__.py#L1073

```python
from enum import Enum


class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
    DEFAULT = (0,)
    NO_OPMATH = (1,)
    INT_TO_FLOAT = (2,)
    ALWAYS_BOOL = (3,)
    COMPLEX_TO_FLOAT = (4,)
    BOOL_TO_LONG = (5,)


class REDUCTION_OUTPUT_TYPE_KIND(Enum):
    SAME = (0,)
    COMPLEX_TO_FLOAT = (1,)  # for complex types outputs corresponding real type
    KEEP_PROMOTED_TYPE = (2,)  # keep output in opmath type, needed for mean
    ALWAYS_BOOL = (3,)
```

Usage: https://github.com/pytorch/pytorch/blob/23b7035b3c61b0900e900e39ba7e69997a417ea8/torch/_prims_common/__init__.py#L1291-L1333
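
For context, a small sketch of how these kinds are consumed on the PyTorch side via `torch._prims_common.elementwise_dtypes` (a private API that may change between releases; the specific tensors below are only an example):

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, elementwise_dtypes

a = torch.tensor([1, 2], dtype=torch.int32)
b = torch.tensor([0.5], dtype=torch.float16)

# Returns (computation dtype, result dtype). With DEFAULT promotion the result
# follows PyTorch's promotion rules (float16 here), while the computation dtype
# may be upcast (e.g. float32) for reduced-precision floats.
computation_dtype, result_dtype = elementwise_dtypes(
    a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
print(computation_dtype, result_dtype)
```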

justinchuby added a commit that referenced this issue Jun 23, 2023
justinchuby added a commit that referenced this issue Jun 23, 2023
justinchuby added a commit that referenced this issue Jun 24, 2023
justinchuby added a commit that referenced this issue Jun 24, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 24, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 24, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 26, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 26, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 26, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 26, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 27, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 27, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 27, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 27, 2023
-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Jun 27, 2023
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #792

-----

- Mark nan==nan in assert_close
- Fix aten_isfinite by allowing the exporter to do type promotion
according to #563 and introduce `TFloatHighPrecision` for it
- Fix typo in `TRealOrUInt8`
@thiagocrepaldi
Copy link
Contributor

@justinchuby should we close this?

@justinchuby
Copy link
Collaborator Author

justinchuby commented Jul 31, 2023

It's half done, with the ONNX part not being urgent right now. I would leave it open.

@thiagocrepaldi
Copy link
Contributor

np.
Is it possible to add the missing/completed tasks to help us know what is and is not done?

@justinchuby
Copy link
Collaborator Author

Sure! Done.

@thiagocrepaldi
Copy link
Contributor

@justinchuby should this be closed?

@justinchuby
Copy link
Collaborator Author

Yes. Thanks!

For reference, of the two points under ONNX: (1) is not needed in practice; (2) is handled by overloads.
