Skip to content

Commit c6fb9e0

Browse files
committed
Update on "[ONNX] Introduce FX-ONNX dispatcher"
Needs microsoft/onnxscript#721 The current FX exporter is using manually maintained dictionary to map ATen op to its OnnxFunction. However, the issue arises when ATen op has overloads or OnnxFunction has overloads, which is not resolvable by the one to one mapping . For example, `aten::arange` has onverloads: `aten::arange.start` and `aten::arange.start_step`, or for `aten::argmax`, torchlib provides two function: aten_argmax, and aten_argmax_dim. This PR utilizes newly introduced [ONNX OpSchema](microsoft/onnxscript#626) to match the input arguments of an ATen operator to find the correct overload. ### OnnxRegistry Heavily reference on [TorchScript Registry](#84382). The only difference is that in FX registry, an ATen operator with specific opset version is mapped to a list of overloaded functions. * No longer use global registry. The registry is initialized in `ResolvedExportOptions` with torchlib, and will be exposed to users in the future. * Multiple opset version layer is kept through `_SymbolicFunctionGroup` , but torchlib now only supports 18. * Basic API of custom operator support: `register`, `unregister`, and `is_register_op` are kept for future development. To further complete them, the follow-up PRs should address: - How to allow users to remove/override specific overload? Using OpSchema to differentiate? - User registers a new overload with the same OpSchema as one of registered overload. ### OnnxDispatcher Dispatch ATen operators to the matched overload by comparing OpSchema with input arguments. * `dispatch` uses `OpSchemaWrapper` to compare data types to find matched overload. * `dispatch_opset_version` is referenced from #84382 and kept, but torchlib doesn't support opset version != 18. * To include more supports: the follow-up PRs should address: - How to add op.Cast with autocast? In torchlib or converter? - Does Type promotion need OpSchema consultant? If not, it can be done in graph-level pass. Co-authored-by: Justin Chu <justinchubymicrosoft.com> [ghstack-poisoned]
2 parents e535910 + 65fc3e5 commit c6fb9e0

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/onnx/_internal/fx/function_dispatcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def dispatch(
8484
node: torch.fx.Node,
8585
onnx_args: Sequence,
8686
onnx_kwargs: Mapping,
87-
) -> Union[onnxscript.OnnxFunction, onnxscript.TracedOnnxFunction]:
87+
) -> Union["onnxscript.OnnxFunction", "onnxscript.TracedOnnxFunction"]:
8888
"""Dispatches an ONNX function based on the given FX node, arguments, and keyword arguments.
8989
9090
Args:

0 commit comments

Comments
 (0)