api to grab base seed as device data #4293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
1 commit merged into master on Dec 8, 2022

Conversation

@shunting314 (Collaborator) commented on Dec 7, 2022

First of all, I'd like to lay out my understanding of how random numbers are used in torchxla:

When tracing modules that use random numbers, like Dropout, torchxla needs to call XLAGraphExecutor::GetRngSeed to grab the current seed and put it into the traced graph. The API has a side effect of updating a DeviceContext object so that the next call returns a different seed. Here is an example from the bernoulli op used by Dropout: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L782 . In this example, we can see that the current seed is grabbed as an IR node and used to create the Bernoulli IR node.

In our dynamo/torchxla bridge we reuse the traced graph. To maintain numerical correctness (for benchmarking), the bridge should have the same behavior as regular torchxla. This PR changes the GetRngSeedData method to do the same thing as GetRngSeed while guaranteeing that a DeviceData node is returned. In the bridge, whenever we prepare a graph input that is identified as a seed, we should call GetRngSeedData.
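
A minimal Python sketch of the two behaviors described above. The names (SeedContext, DeviceData, get_rng_seed, get_rng_seed_data) are illustrative stand-ins for the real C++ XLAGraphExecutor/DeviceContext API, and the seed-update rule is made up; the point is only the difference between "advance and return the seed" and "do the same, but guarantee a device-data graph input".

```
from dataclasses import dataclass

@dataclass
class DeviceData:
    """Stand-in for a graph input backed by data on the device (here, just the seed)."""
    value: int

class SeedContext:
    """Mimics the per-device seed state kept by DeviceContext (hypothetical sketch)."""

    def __init__(self, base_seed: int = 101):
        self._seed = base_seed

    def get_rng_seed(self) -> int:
        """Like GetRngSeed: return the current seed for the traced graph and advance
        the stored seed, so the next random op traced sees a different value."""
        current = self._seed
        # The real update rule lives in the C++ code; a simple deterministic
        # step stands in for it here.
        self._seed = (self._seed * 6364136223846793005 + 1) % (1 << 64)
        return current

    def get_rng_seed_data(self) -> DeviceData:
        """Like the GetRngSeedData described in this PR: same behavior as
        get_rng_seed, but the result is guaranteed to come back as a DeviceData
        node that the dynamo/torchxla bridge can bind as a graph input."""
        return DeviceData(self.get_rng_seed())
```

In this sketch, a bridge that identifies a graph input as a seed would call the DeviceData-returning variant, so the replayed graph consumes the same kind of seed input the original trace produced.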

@shunting314 force-pushed the api_grab_running_seed_as_device_data branch from 1c732ef to cb8dc8e on December 7, 2022 07:45
@shunting314 marked this pull request as ready for review on December 7, 2022 07:56
@shunting314 changed the title from "api to grad running seed as device data" to "api to grab running seed as device data" on Dec 7, 2022
@shunting314 requested a review from JackCaoG on December 7, 2022 07:57
@shunting314 (Collaborator, Author) commented:

I confirmed the fix on all 9 models we are looking at (excluding squeezenet, since it falls back to avg_pool2d as a workaround for the max_pool2d issue).

@JackCaoG added the dynamo label on Dec 7, 2022
@shunting314 force-pushed the api_grab_running_seed_as_device_data branch from cb8dc8e to d4af001 on December 8, 2022 00:06
@shunting314 requested a review from JackCaoG on December 8, 2022 00:09
@JackCaoG (Collaborator) left a comment:

Thanks!

@shunting314 force-pushed the api_grab_running_seed_as_device_data branch from d4af001 to 16d349e on December 8, 2022 01:58
@shunting314 merged commit feba771 into master on Dec 8, 2022
@shunting314 changed the title from "api to grab running seed as device data" to "api to grab base seed as device data" on Dec 8, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Jan 5, 2023:
We've already shown some promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for both training and inference, in this PR we try to enable training for dynamo/torchxla.

Training is trickier than inference, and we may not expect as much perf gain, since:
1. In the training case, torchxla only generates a single combined graph for fwd/bwd/optimizer, while in the `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we generate 3 graphs: one for the forward pass, one for the backward pass and one for the optimizer. XLA favors larger graphs since they give it more room for optimization.
2. In the training case, tracing overhead can be overlapped with computation, so tracing overhead is not as big a deal for training as it is for inference. After all, training cares more about throughput while inference cares more about latency.
3. In the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the tracing overhead, so the tracing overhead 'per example' shrinks (see the sketch after this list).
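
As a back-of-the-envelope illustration of item 3 (the numbers below are made up, not measured), a roughly fixed per-iteration tracing cost shrinks per example as the batch size grows:

```
# Illustrative only: assume a fixed tracing cost per iteration and see how the
# per-example share of that cost drops as the batch size increases.
tracing_overhead_ms = 50.0  # assumed fixed cost to trace one training iteration

for batch_size in (32, 128, 512):
    per_example_ms = tracing_overhead_ms / batch_size
    print(f"batch_size={batch_size:4d}: {per_example_ms:.3f} ms tracing overhead per example")
```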

But we still want to add training support to dynamo/torchxla to make the work complete.

We added the '--iterations-per-run' argument to control how many iterations we run per measurement/device sync. This is to understand the impact of item 2 above.
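
A rough sketch of what a measurement window covering several iterations per device sync looks like; the functions below (run_one_iteration, device_sync, measure_window) are placeholders, not the benchmark's actual API:

```
import time

def measure_window(run_one_iteration, device_sync, iterations_per_run: int = 4) -> float:
    """Time one measurement window: `iterations_per_run` training iterations
    followed by a single device sync, so tracing for iteration i can overlap
    with device execution of iteration i-1."""
    start = time.perf_counter()
    for _ in range(iterations_per_run):
        run_one_iteration()  # placeholder for one fwd/bwd/optimizer step
    device_sync()            # placeholder for waiting on all queued device work
    return time.perf_counter() - start

# Demo with no-op placeholders:
elapsed = measure_window(run_one_iteration=lambda: None, device_sync=lambda: None)
print(f"measured window: {elapsed:.6f} s")
```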

Results:

With '--iterations-per-run' equal to 1, here are the perf numbers:
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |   XLA (trace everytime) |
+=========================+====================+=========================+
| resnet18                |             0.91   |                0.959    |
+-------------------------+--------------------+-------------------------+
| resnet50                |             0.917  |                0.932    |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |             0.912  |                0.905    |
+-------------------------+--------------------+-------------------------+
| alexnet                 |             1.038  |                0.974    |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |             0.881  |                0.835    |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |             0.903  |                0.931    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |             0.914  |                0.967    |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |             1.359  |                0.84     |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |             1.288  |                0.893    |
+-------------------------+--------------------+-------------------------+
| geomean                 |             1.0006 |                0.913794 |
+-------------------------+--------------------+-------------------------+
```

Overall it looks like graph breaks indeed cause perf loss. But for BERT_pytorch and timm_vision_transformer we still see perf gains. We need to do more experiments with larger '--iterations-per-run' values.

NOTE:
In torchbench.py I added the following code to apply a few workarounds:
```
from myscripts import workaround # TODO will remove this line before landing
```

Here is the content of workaround.py:
```
import torch
from torch import nn
import os

# override max_pool2d with avg_pool2d
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d

```

It works around a few issues we found (a quick check of the override follows this list):
1. MaxPool2d does not work for training in dynamo/torchxla: pytorch/torchdynamo#1837 . WIP fix from Brian in #90226 , https://github.com/pytorch/xla/pull/4276/files (WIP)
2. A recent change (this PR: #88697) in op decomposition causes batch_norm ops to fall back in torchxla. Fix from Jack in pytorch/xla#4282 (comment). (Confirmed the fix after adding a Deduper to handle duplicated returns from the fx graph generated by AOTAutograd.)
3. We have an issue handling dropout because the random seed gets out of sync. Here is the fix: pytorch/xla#4293 (confirmed the fix).
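
As a quick illustrative check (not part of the PR), the override in workaround.py means any later construction of nn.MaxPool2d actually yields an AvgPool2d:

```
import torch

# Same override as in workaround.py, applied unconditionally for this check.
torch.nn.MaxPool2d = torch.nn.AvgPool2d

pool = torch.nn.MaxPool2d(kernel_size=2)
print(type(pool))  # <class 'torch.nn.modules.pooling.AvgPool2d'>
```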

Example command:
```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```

Pull Request resolved: #88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet