[torchlib] trace_ok #1095
Discussed with @BowenBao. With the passes Bowen has created, we should be able to fold away the if branches and reach comparable performance. Given this, it may make more sense to follow a profile-guided process to avoid premature optimization on the torchlib side. I will instead:
cc @thiagocrepaldi, happy to sync on this sometime this week with @BowenBao and others.
@gramalingam's work on constant folding as a post-export step might do the trick too, so it is fine to table `trace_ok` for now.
Which passes are we talking about, specifically? Can you provide links to the code? ORT core mentioned they noticed graphs with several If nodes in them. I can check the code, compare with a graph, and see what is up.
Do you foresee any possible issues with constant folding this later? Will it create more branching? Indeed, this seems like a high-value task to work on.
Awesome. I was wondering whether we could generalize the idea from this op to a wider scope. That is, how practical would it be to create "hardware-specific" or "scenario-specific" (e.g. fp32 vs fp16, static input vs dynamic input) overloads of (a few?) operators we care about? That way, users could export either the most generic version (the default behavior) or the most specialized variant. This is kind of a short-circuit for the whole optimization tooling we need to work on, but it might help in the short term if it is expected to give gains with reasonable effort.
Let's discuss that at our weekly sync.
1: I don't see it affecting constant folding; it should just be another ONNX function defined as the composition (Size ∘ Shape).
This change introduces two shared operators, `Rank` and `IsScalar`. They are used to replace the `Size(Shape())` pattern for code reuse and readability. I used a hack to always include these shared functions in the model proto because, without #834, we cannot dynamically add these functions to the model as they are used. I added a TODO for this. The first usage is in `aten_all`. I will update the rest of the functions in a separate PR. #1095
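For reference, here is a minimal sketch of what such shared helpers could look like in onnxscript. It assumes the plain `@script()` decorator and `FLOAT` inputs for simplicity; the real torchlib functions live in the shared `pkg.onnxscript.torch_lib.common` domain and accept any tensor type, as shown in the exported graphs later in this thread.

```python
# Minimal sketch, assuming onnxscript's @script decorator; not the actual
# torchlib implementation (types and the function domain are simplified).
from onnxscript import BOOL, FLOAT, INT64, script
from onnxscript import opset18 as op


@script()
def Rank(input: FLOAT[...]) -> INT64:
    """Rank of a tensor, computed as Size(Shape(input))."""
    return op.Size(op.Shape(input))


@script()
def IsScalar(input: FLOAT[...]) -> BOOL:
    """True when the input is 0-D, i.e. its rank equals 0."""
    return op.Size(op.Shape(input)) == 0
```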
Hi @justinchuby, As an ONNX backend developer for onnx-tensorrt, similar to the reports from @thiagocrepaldi, we have seen a large increase in models with `If` nodes. I see there are current efforts to do the constant-folding behavior, and since I'm new to ONNXScript and dynamo export I would like your opinion on the recommended paths for backends.
Thanks, and I look forward to the future of ONNXScript export!
Thanks for your questions!
Thanks for the response. @BowenBao, are you able to share the timeline for when this folding option will arrive in the dynamo exporter? I will look more into function overrides. Currently the most common issue we see is extra conditionals; we'll do more analysis on other models to see if specific function overrides are required.
This work is at a very early development stage, and we don't have a timeline just yet.
@kevinch-nv Please feel free to let us know your progress and whether you hit any roadblocks along the way.
No updates yet; I'm still gathering some data on the models exported through dynamo & ONNXScript. My most pressing question is the timeline for dynamo export replacing the current torchscript export, and the list of features (such as conditional block folding) that are planned to be pushed in prior to the official switch. This will provide some more information to backends, as we would like to avoid regressions between models exported through the two paths and avoid having to implement (potentially) duplicate optimizations.
@justinchuby is there an example of how we can force a tracing pass for the
On it
As it currently stands, one can only trace the whole function by removing the `torch_op` decorator.
Thanks Justin. I did some more digging and there doesn't seem to be an easy way to force a trace for certain ops without modifying the source. I've noticed that the For the
General tracing support is tracked by #1083 and #694. This is a (slightly) longer-term effort, as we need to create a better graph class to support it natively in onnxscript. For the PyTorch ONNX converter, however, you would only need to remove the decorator from the custom function for it to be traced. The
Thanks for the pointers, it's good to know that these things are being worked on!
Not sure what you mean by that; can you provide an example? My concern is that because only
Yes. Consider `onnxscript/onnxscript/function_libs/torch_lib/ops/core.py`, lines 254 to 276 at b7f215e:

```python
# Imports omitted
def aten_addr(
    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

    Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
    """
    vec1_shape = op.Constant(value_ints=[-1, 1])
    vec2_shape = op.Constant(value_ints=[1, -1])
    vec1_reshaped = op.Reshape(vec1, vec1_shape)
    vec2_reshaped = op.Reshape(vec2, vec2_shape)
    outer = op.MatMul(vec1_reshaped, vec2_reshaped)
    # https://github.com/pytorch/pytorch/blob/51664489ba6f6b2343bbec9af9ca99185e2a5dbc/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp#L53-L54
    # When beta == 0, values in self should be ignored,
    # nans and infs in self should not propagate.
    if beta == 0.0:
        result = op.Mul(alpha, outer)
    else:
        result = op.Add(op.Mul(beta, self), op.Mul(alpha, outer))
    return result
```

Then register it with the PyTorch API (https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.OnnxRegistry.register_op): `torch.onnx.register_op(aten_addr, "aten", "addr")`
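For completeness, a hedged sketch of wiring such a custom implementation into the dynamo exporter, following the `torch.onnx.OnnxRegistry.register_op` documentation linked above; whether a plain (undecorated) Python function is accepted, and the exact option names, may vary across PyTorch versions.

```python
# Hedged sketch, based on the torch.onnx.OnnxRegistry docs linked above.
# `aten_addr` is assumed to be the custom implementation defined earlier,
# with the torchlib decorator removed so it is traced rather than kept as
# an ONNX function; `model` and `example_inputs` are placeholders.
import torch

onnx_registry = torch.onnx.OnnxRegistry()
onnx_registry.register_op(function=aten_addr, namespace="aten", op_name="addr")

export_options = torch.onnx.ExportOptions(onnx_registry=onnx_registry)
# onnx_program = torch.onnx.dynamo_export(model, *example_inputs, export_options=export_options)
```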
`aten_clamp_min`
@kevinch-nv you may now test with the `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` flag.
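As a concrete usage sketch (the model, shapes, and file name are illustrative; the environment variable comes from the PR descriptions below and should be set before torchlib is imported or used), exporting a `clamp_min` call with the flag enabled might look like this:

```python
import os

# Opt in to the experimental tracing behavior before torchlib is loaded.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"

import torch


class ClampMin(torch.nn.Module):
    def forward(self, x, y):
        return torch.clamp_min(x, y)


x = torch.randint(0, 10, (5, 10, 5), dtype=torch.int32)
y = torch.randint(0, 10, (10, 5), dtype=torch.int32)

# With tracing preferred, the exported graph for clamp_min reduces to a Max
# node instead of carrying If branches (see the example output below).
onnx_program = torch.onnx.dynamo_export(ClampMin(), x, y)
onnx_program.save("clamp_min.onnx")
```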
…| feat(torchlib) (#1176)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #1178
* #1177
* __->__ #1176

As part of the effort described in #1095, this PR:
- adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag to allow the tracer to trace a function when possible.
- defines the `traceable` option in the `torch_op` decorator to mark a function as traceable.

Co-authored-by: Ti-Tai Wang <[email protected]>
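A hedged sketch of what marking a function as traceable could look like in torchlib: the `torch_op` decorator name and the `traceable` option come from the PR description above, while the import paths and the simplified body are assumptions for illustration only.

```python
# Illustrative only; import paths are assumed and the body is simplified.
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TReal


@torch_op("aten::clamp_min", traceable=True)
def aten_clamp_min(self: TReal, min_: TReal) -> TReal:
    # Simplified stand-in body: the real implementation also handles the
    # scalar/rank special cases that the tracer can now fold away.
    return op.Max(self, min_)
```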
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* #1178
* __->__ #1177
* #1176

As part of the effort described in #1095, this PR marks functions with `if` branches as traceable.
…es | feat(torchlib) (#1178)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #1178

As part of the effort described in #1095, this PR:
- implements the experimental evaluator for folding branches and castlikes so that they are eagerly evaluated when possible.
- updates the implementation of `addr` so that it is traceable.
- conditionally enables previously xfailed tests.

Set `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1`; tested in CI. For example, `clamp_min` now becomes:

```
<
   ir_version: 8,
   opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9)
   <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8>
{
   _val_2 = Size (input_0)
   _val_3 = Shape <start: int = 0> (input_1)
   _val_4 = Size (_val_3)
   _val_5 = Constant <value: tensor = int64 {0}> ()
   _val_6 = Equal (_val_2, _val_5)
   _val_7 = Constant <value: tensor = int64 {0}> ()
   _val_8 = Equal (_val_4, _val_7)
   _val_9 = Max (input_0, input_1)
}
<
   domain: "pkg.onnxscript.torch_lib.common",
   opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
   domain: "pkg.onnxscript.torch_lib.common",
   opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```
Potentially allow a mode that evaluates `if` conditions on ranks eagerly so that the `if` branches can be folded away.
#1089
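For illustration, a hedged sketch (not actual torchlib code; names and types are simplified) of the kind of function this issue targets: the condition depends only on an input's rank, so an eager/tracing evaluator can resolve it at export time instead of emitting an ONNX If node.

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op


@script()
def example_clamp_min(self: FLOAT[...], min_: FLOAT[...]) -> FLOAT[...]:
    # Compiled as an ONNX function, this rank check becomes an If node;
    # evaluated eagerly while tracing, the rank is known and the branch folds away.
    min_is_scalar = op.Size(op.Shape(min_)) == 0
    if min_is_scalar:
        # Clip expects a scalar bound.
        result = op.Clip(self, min_)
    else:
        # Max handles the broadcasting (non-scalar) case.
        result = op.Max(self, min_)
    return result
```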