[torchlib] trace_ok #1095

Closed
4 of 5 tasks
justinchuby opened this issue Oct 18, 2023 · 18 comments
Assignees: justinchuby
Labels: module: torchlib (Related to the torch/aten function lib in development), topic: discussion (For discussion)

Comments

@justinchuby
Collaborator

justinchuby commented Oct 18, 2023

Potentially allow a mode that evaluates if-conditions on ranks so that the if-branches can be folded away.

#1089
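
For context, the kind of torch_lib function this would affect looks roughly like the sketch below; in such a mode, the rank comparison would be evaluated eagerly against the known static rank, so only one branch is emitted instead of an `If` node. This is an illustrative sketch following torch_lib conventions, not the actual source, and the import paths are assumptions:

```python
from onnxscript import opset18 as op
from onnxscript.function_libs.torch_lib.registration import torch_op
from onnxscript.function_libs.torch_lib.tensor_typing import TTensor


@torch_op("aten::t")
def aten_t(self: TTensor) -> TTensor:
    rank = op.Size(op.Shape(self))  # the pattern a shared Rank helper would wrap
    if rank == 2:  # in script mode this lowers to an ONNX If node
        result = op.Transpose(self, perm=[1, 0])
    else:
        result = op.Identity(self)  # 0-D and 1-D inputs are returned unchanged
    return result
```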

justinchuby added the module: torchlib and topic: discussion labels on Oct 18, 2023
justinchuby self-assigned this on Oct 20, 2023
@justinchuby
Collaborator Author

justinchuby commented Oct 23, 2023

Discussed with @BowenBao. With the passes Bowen has created, we should be able to fold away the if branches and hit comparable performance.

Given this, it may make more sense to follow a profile-guided process to avoid premature optimization on the torchlib side. I will instead:

  1. Still create the Rank op to simplify the functions and improve readability
  2. Drive "[Feature request] field to represent shape/type for nodes within function" (onnx/onnx#5487)
  3. Create Gemm overloads to temporarily fix "Linear from PyTorch must map to Gemm in ONNX" (#1089)

cc @thiagocrepaldi happy to sync on this sometime this week with @BowenBao and others.

@thiagocrepaldi
Contributor

thiagocrepaldi commented Oct 23, 2023

@gramalingam's work on constant folding as a post-export step might do the trick too, so it is fine to table trace_ok for now.

> Discussed with @BowenBao. With the passes Bowen has created, we should be able to fold away the if branches and hit comparable performance.

Which passes are we talking about, specifically? Can you provide links to the code? ORT core mentioned they noticed graphs with several If nodes in them. I can check the code, compare it against such a graph, and see what is going on.

> Given this, it may make more sense to follow a profile-guided process to avoid premature optimization on the torchlib side. I will instead:

>   1. Still create the Rank op to simplify the functions and improve readability

Do you foresee any possible issues with constant folding this later? Will it create more branching?

>   2. Drive "[Feature request] field to represent shape/type for nodes within function" (onnx/onnx#5487)

Indeed, this seems like a high-value task to work on.

>   3. Create Gemm overloads to temporarily fix "Linear from PyTorch must map to Gemm in ONNX" (#1089)

Awesome. I was wondering whether we could generalize the idea from this op to a wider scope. That is, how practical would it be to create "hardware-specific" or "scenario-specific" (e.g. fp32 vs fp16, static vs dynamic input) overloads of (a few?) operators we care about? That way, users could export either the most generic version (the default behavior) or the most specialized variant. This is kind of a short-circuit for the whole optimization tooling we need to work on, but it may help in the short term if it is expected to give gains with reasonable effort.

> cc @thiagocrepaldi happy to sync on this sometime this week with @BowenBao and others.

Let's discuss that in our weekly sync.

@justinchuby
Collaborator Author

justinchuby commented Oct 23, 2023

1. I don't see it affecting constant folding; it should just be another ONNX function defined as the composition Size ∘ Shape (see the sketch below).
3. Good question. I will need to think about it more.
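
A sketch of what that shared function could look like in onnxscript; the custom domain matches the one that later shows up in the exported models, but the import paths and registration details here are illustrative assumptions:

```python
from onnxscript import INT64, opset18 as op, script
from onnxscript.function_libs.torch_lib.tensor_typing import TTensor
from onnxscript.values import Opset

# Assumed custom domain for shared torch_lib helpers.
common_opset = Opset("pkg.onnxscript.torch_lib.common", 1)


@script(common_opset)
def Rank(input: TTensor) -> INT64:
    """Rank is simply the composition Size(Shape(input))."""
    return op.Size(op.Shape(input))
```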

justinchuby added a commit that referenced this issue Oct 24, 2023
This change introduces two shared operators `Rank` and `IsScalar`. They
are used to replace the `Size(Shape())` pattern for code reuse and
readability.

I used a hack to always include these shared functions in the model
proto because without #834 we cannot dynamically add these functions to
the model as they are used. I added a TODO for this.

The first usage is in `aten_all`. I will update the rest of the
functions in a separate PR.

#1095
@kevinch-nv

Hi @justinchuby,

As an ONNX backend developer for onnx-tensorrt, and similar to the reports from @thiagocrepaldi, we have seen a large increase in models with If conditionals generated through the dynamo torch exporter using ONNXScript. The majority of these seem to be based on the type/rank of the input (e.g. aten::t), which, as others have mentioned, can be constant-folded out if the shape information is correctly propagated to the exporter.

I see there are ongoing efforts to implement the constant-folding behavior, and since I'm new to ONNXScript and dynamo export, I would like your opinion on the recommended paths for backends.

  1. Is the plan for ONNXScript to add a trace mode during export to perform this folding for us?
  2. Is there a recommended path for backend developers to provide custom overrides of the default torch_lib functions and implement tracing behavior ourselves?
  3. If not, is it up to the backend, or should this functionality be consolidated into tools like the onnx-simplifier or built-in ONNX graph optimizers?

Thanks, and I look forward to the future of ONNXScript export!

@justinchuby
Collaborator Author

Thanks for your questions!

  1. The current plan is for ONNX/PyTorch to include an option for folding. It will become a standard facility in ONNX. @BowenBao has more info on this.
  2. There are ways to override functions in torch.onnx.dynamo_export: https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.OnnxRegistry.register_op. You are welcome to create traced functions for your use case; alternatively, voice your needs in this repo so that we know what features would be helpful.
  3. For graph rewriting, we plan to have more robust tooling for that as well.

@kevinch-nv

Thanks for the response - @BowenBao are you able to share the timeline of when this folding option will arrive in the dynamo exporter?

I will look more into function overrides - currently the most common issue we see is extra conditionals; we'll do more analysis on other models to see if specific function overrides are required.

@thiagocrepaldi
Contributor

> Thanks for the response - @BowenBao are you able to share the timeline of when this folding option will arrive in the dynamo exporter?
>
> I will look more into function overrides - currently the most common issue we see is extra conditionals; we'll do more analysis on other models to see if specific function overrides are required.

This work is at the very early development stage, and we don't have a timeline just yet.

@justinchuby
Collaborator Author

@kevinch-nv Please feel free to let us know your progress and if you hit any roadblocks along the way

@kevinch-nv

No updates yet, I'm still gathering some data on the models exported through dynamo & ONNXScript.

My most pressing question is about the timeline for dynamo export replacing the current torchscript export, and the list of features (such as conditional block folding) that are planned to land before the official switch.

This will provide more information to backends, as we would like to avoid regressions between models exported through the two paths and avoid having to implement (potentially) duplicate optimizations.

@kevinch-nv

@justinchuby is there an example of how we can force a tracing pass for the Rank() operation to fold down conditional branches on the export side? I'm able to overwrite default functions with a custom ONNX registry, but it's still unclear how tracing can be run to avoid these conditionals.

@justinchuby
Collaborator Author

On it

@justinchuby
Collaborator Author

As it currently stands, one can only trace a whole function by removing the @script() decorator from it. To force onnxscript to evaluate Rank(), consider overriding the eval_function method in TorchScriptTracingEvaluator and creating a special case: when the function is Rank, return an integer instead.

`def eval_function(  # type: ignore[override]`
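
A rough, untested sketch of that special-casing; it assumes `TorchScriptTracingEvaluator` lives in `onnxscript.function_libs.torch_lib.graph_building` and that the traced tensor exposes a static `shape` attribute, both of which are assumptions rather than documented API:

```python
from onnxscript.function_libs.torch_lib import graph_building


class RankFoldingEvaluator(graph_building.TorchScriptTracingEvaluator):
    def eval_function(self, function, args, kwargs):  # type: ignore[override]
        if function.name == "Rank":
            (input_value,) = args
            shape = getattr(input_value, "shape", None)  # static shape, if known
            if shape is not None:
                # Return a plain Python int so `if` conditions on the rank
                # are evaluated eagerly and the branches fold away.
                return len(shape)
        return super().eval_function(function, args, kwargs)
```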

@kevinch-nv

Thanks Justin - I did some more digging and there doesn't seem to be an easy way to force a trace for certain ops without modifying the source.

I've noticed that the @script() decorator must return an OnnxFunction, and there's no @traced_script() equivalent that returns a values.TracedOnnxFunction type, meaning that all function overloads are by default compiled and not traced.

For the torch_op decorator, there is a trace_only option to handle ops that must have certain values as attributes instead of inputs. Can we also provide a similar option for the script() decorator?

@justinchuby
Collaborator Author

justinchuby commented Nov 16, 2023

General tracing support is tracked by #1083 and #694. This is a (slightly) longer-term effort, as we need to create a better graph class to support it natively in onnxscript.

For the PyTorch ONNX converter, however, you would only need to remove the decorator from the custom function for it to be traced. The trace_only option simply wraps the function with metadata and exposes the raw function to the onnxscript evaluator. Not decorating at all should have the same effect.

@kevinch-nv

kevinch-nv commented Nov 16, 2023

Thanks for the pointers, it's good to know that these things are being worked on!

> For the PyTorch ONNX converter, however, you would only need to remove the decorator from the custom function for it to be traced.

Not sure what you mean by that; can you provide an example? My concern is that, because only @script() functions are exported, we cannot use a traced function to override default compiled torch functions through a custom ONNX registry.

@justinchuby
Collaborator Author

justinchuby commented Nov 22, 2023

Yes. Consider:

```python
@torch_op("aten::addr")
def aten_addr(
    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
    """
    vec1_shape = op.Constant(value_ints=[-1, 1])
    vec2_shape = op.Constant(value_ints=[1, -1])
    vec1_reshaped = op.Reshape(vec1, vec1_shape)
    vec2_reshaped = op.Reshape(vec2, vec2_shape)
    outer = op.MatMul(vec1_reshaped, vec2_reshaped)
    # https://github.com/pytorch/pytorch/blob/51664489ba6f6b2343bbec9af9ca99185e2a5dbc/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp#L53-L54
    # When beta == 0, values in self should be ignored,
    # nans and infs in self should not propagate.
    if beta == 0.0:
        result = op.Mul(alpha, outer)
    else:
        result = op.Add(op.Mul(beta, self), op.Mul(alpha, outer))
    return result
```

For it to be traced, you would only need to remove the decorator:

```python
# Imports omitted

def aten_addr(
    self: TReal, vec1: TReal, vec2: TReal, beta: float = 1.0, alpha: float = 1.0
) -> TReal:
    """addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor

    Performs the outer-product of vectors vec1 and vec2 and adds it to the matrix input.
    """
    vec1_shape = op.Constant(value_ints=[-1, 1])
    vec2_shape = op.Constant(value_ints=[1, -1])
    vec1_reshaped = op.Reshape(vec1, vec1_shape)
    vec2_reshaped = op.Reshape(vec2, vec2_shape)

    outer = op.MatMul(vec1_reshaped, vec2_reshaped)
    # https://github.com/pytorch/pytorch/blob/51664489ba6f6b2343bbec9af9ca99185e2a5dbc/aten/src/ATen/native/cpu/LinearAlgebraKernel.cpp#L53-L54
    # When beta == 0, values in self should be ignored,
    # nans and infs in self should not propagate.
    if beta == 0.0:
        result = op.Mul(alpha, outer)
    else:
        result = op.Add(op.Mul(beta, self), op.Mul(alpha, outer))

    return result
```

Then register it with the PyTorch API (https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.OnnxRegistry.register_op):

registry = torch.onnx.OnnxRegistry()
registry.register_op(aten_addr, "aten", "addr")
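
To make the export pick up the override, the registry is passed through the export options; a minimal sketch, where `model` and `example_inputs` are placeholders for your own module and inputs:

```python
import torch

# `registry` is the OnnxRegistry created above; `model` and `example_inputs`
# are placeholders for your own module and inputs.
export_options = torch.onnx.ExportOptions(onnx_registry=registry)
onnx_program = torch.onnx.dynamo_export(model, *example_inputs, export_options=export_options)
onnx_program.save("model.onnx")
```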

@justinchuby
Collaborator Author

`aten_clamp_min`:

```
<
   ir_version: 8,
   opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9) 
   <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8>
{
   _val_2 = Size (input_0)
   _val_3 = Shape <start: int = 0> (input_1)
   _val_4 = Size (_val_3)
   _val_5 = Constant <value: tensor = int64 {0}> ()
   _val_6 = Equal (_val_2, _val_5)
   _val_7 = Constant <value: tensor = int64 {0}> ()
   _val_8 = Equal (_val_4, _val_7)
   _val_9 = Max (input_0, input_1)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```

justinchuby added a commit that referenced this issue Nov 22, 2023
As an effort described in #1095, this PR marks some functions as traceable.

[ghstack-poisoned]
justinchuby added a commit that referenced this issue Nov 23, 2023
…ble option | feat(torchlib)"


As an effort described in #1095, this PR

- adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag to allow the tracer to trace a function when possible.
- defined the `traceable` option in the torch_op decorator to mark a function as `traceable`.

[ghstack-poisoned]
@justinchuby
Collaborator Author

@kevinch-nv you may now test this by setting the environment variable TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1.
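
A quick way to try it from Python, assuming the flag is read when torchlib is first imported (so it should be set before importing torch/onnxscript):

```python
import os

# Set before importing torch.onnx / onnxscript so torchlib picks up the flag.
os.environ["TORCHLIB_EXPERIMENTAL_PREFER_TRACING"] = "1"

import torch  # noqa: E402

# ... build `model` and `example_inputs`, then export as usual:
# onnx_program = torch.onnx.dynamo_export(model, *example_inputs)
```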

justinchuby added a commit that referenced this issue Nov 28, 2023
…| feat(torchlib) (#1176)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1178
* #1177
* __->__ #1176

As an effort described in
#1095, this PR

- adds an experimental `TORCHLIB_EXPERIMENTAL_PREFER_TRACING` flag to
allow the tracer to trace a function when possible.
- defined the `traceable` option in the torch_op decorator to mark a
function as `traceable`.

---------

Co-authored-by: Ti-Tai Wang <[email protected]>
justinchuby added a commit that referenced this issue Nov 28, 2023
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* #1178
* __->__ #1177
* #1176

As an effort described in
#1095, this PR marks
functions with if branches as traceable.
justinchuby added a commit that referenced this issue Nov 29, 2023
…es | feat(torchlib) (#1178)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at
bottom):
* __->__ #1178


As an effort described in
#1095, this PR

- Implements the experimental evaluator for folding branches and
castlikes so that they are eagerly evaluated when possible.
- Updates implementation for `addr` for it to be traceable.
- Conditionally enabled previously xfailed tests.

Set `TORCHLIB_EXPERIMENTAL_PREFER_TRACING=1` and tested in CI.

E.g. clamp_min now becomes

```
<
   ir_version: 8,
   opset_import: ["" : 18, "pkg.onnxscript.torch_lib.common" : 1],
   producer_name: "pytorch",
   producer_version: "2.2.0"
>
main_graph (int32[5,10,5] input_0, int32[10,5] input_1) => (int32[5,10,5] _val_9) 
   <int64 _val_2, int64[2] _val_3, int64 _val_4, int64 _val_5, bool _val_6, int64 _val_7, bool _val_8>
{
   _val_2 = Size (input_0)
   _val_3 = Shape <start: int = 0> (input_1)
   _val_4 = Size (_val_3)
   _val_5 = Constant <value: tensor = int64 {0}> ()
   _val_6 = Equal (_val_2, _val_5)
   _val_7 = Constant <value: tensor = int64 {0}> ()
   _val_8 = Equal (_val_4, _val_7)
   _val_9 = Max (input_0, input_1)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
Rank (input) => (return_val)
{
   tmp = Shape (input)
   return_val = Size (tmp)
}
<
  domain: "pkg.onnxscript.torch_lib.common",
  opset_import: ["" : 18]
>
IsScalar (input) => (return_val)
{
   tmp = Shape (input)
   tmp_0 = Size (tmp)
   tmp_1 = Constant <value_int: int = 0> ()
   return_val = Equal (tmp_0, tmp_1)
}
```
justinchuby moved this to In Progress in My List on Nov 29, 2023
justinchuby self-assigned this on Apr 19, 2024
github-project-automation bot moved this from In Progress to Done in My List on Mar 5, 2025