api to grab base seed as device data #4293
Merged
Conversation
Force-pushed from 1c732ef to cb8dc8e
This was referenced Dec 7, 2022
I confirmed the fix on all 9 models we are looking at (excluding squeezenet, since it falls back to avg_pool2d as a workaround for the max_pool2d issue).
JackCaoG reviewed Dec 7, 2022
Force-pushed from cb8dc8e to d4af001
JackCaoG approved these changes Dec 8, 2022
Thanks!
Force-pushed from d4af001 to 16d349e
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request on Jan 5, 2023:
We've already shown some promising perf results by integrating dynamo with torchxla for inference. To provide a consistent UX for training and inference, in this PR we try to enable training for dynamo/torchxla. Training is trickier than inference, and we may not expect much perf gain, since:

1. In the training case, torchxla only generates a single combined graph for fwd/bwd/optimizer, while in the `torchxla_trace_once` bridge we added in dynamo, due to how AOT_Autograd works, we generate 3 graphs: one for the forward, one for the backward, and one for the optimizer. XLA favors larger graphs so it can do more optimizations.
2. In the training case, tracing overhead can be overlapped with computation, so tracing overhead is not as big a deal for training as it is for inference. After all, training cares more about throughput while inference cares more about latency.
3. In the training case, people can increase the batch size to 'mitigate' the tracing overhead. Increasing the batch size does not change the tracing overhead, but it makes the tracing overhead 'per example' appear smaller.

But we still want to add training support to dynamo/torchxla to make the work complete. We added an '--iterations-per-run' argument to control how many iterations we do per measure/device sync; this is to understand the impact of item 2 above.

Results: with '--iterations-per-run' equal to 1, here are the perf numbers:

```
+-------------------------+--------------------+-------------------------+
| Model                   | XLA (trace once)   | XLA (trace everytime)   |
+=========================+====================+=========================+
| resnet18                | 0.91               | 0.959                   |
+-------------------------+--------------------+-------------------------+
| resnet50                | 0.917              | 0.932                   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         | 0.912              | 0.905                   |
+-------------------------+--------------------+-------------------------+
| alexnet                 | 1.038              | 0.974                   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            | 0.881              | 0.835                   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              | 0.903              | 0.931                   |
+-------------------------+--------------------+-------------------------+
| vgg16                   | 0.914              | 0.967                   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            | 1.359              | 0.84                    |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer | 1.288              | 0.893                   |
+-------------------------+--------------------+-------------------------+
| geomean                 | 1.0006             | 0.913794                |
+-------------------------+--------------------+-------------------------+
```

Overall it looks like graph breaks indeed cause a perf loss, but for BERT_pytorch and timm_vision_transformer we still see a perf gain. We need to do more experiments with a larger '--iterations-per-run'.

NOTE: In torchbench.py I added the following code as a workaround:

```
from myscripts import workaround # TODO will remove this line before landing
```

Here is the content of workaround.py:

```
import torch
from torch import nn
import os

# override max_pool2d with avg_pool2d
if os.environ.get("REPLACE_MAXPOOL", "0") == "1":
    torch.nn.MaxPool2d = torch.nn.AvgPool2d
```

It works around a few issues we found:

1. MaxPool2d does not work for training in dynamo/torchxla: pytorch/torchdynamo#1837. WIP fix from Brian in #90226 and https://github.com/pytorch/xla/pull/4276/files (WIP).
2. A recent change (this PR #88697) in op decomposition causes batch_norm ops to fall back in torchxla. Fix from Jack in pytorch/xla#4282 (comment). (Confirmed the fix after adding a Deduper to handle duplicated returns from the fx graph generated by AOTAutograd.)
3. We have an issue handling dropout because of a random-seed out-of-sync issue. Here is the fix: pytorch/xla#4293 (confirmed the fix).

Example command:

```
REPLACE_MAXPOOL=1 USE_FAKE_TENSOR=0 GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --training --backend=aot_torchxla_trace_once --only vgg16
```

Pull Request resolved: #88449
Approved by: https://github.com/wconstab, https://github.com/qihqi, https://github.com/malfet
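For readers unfamiliar with the benchmark harness, the sketch below illustrates what the '--iterations-per-run' knob controls: how many training iterations run between device syncs, so the tracing of later iterations can overlap with device execution of earlier ones. This is a simplified, hypothetical loop rather than the actual benchmarks/dynamo/torchbench.py code; `train_step` and `sync_device` are placeholder callables.

```
import time

def measure(train_step, sync_device, iterations_per_run, num_runs):
    """Toy timing loop (hypothetical): each 'run' executes several training
    iterations and only syncs with the device once at the end, so tracing of
    later iterations can overlap with device execution of earlier ones."""
    per_run_times = []
    for _ in range(num_runs):
        start = time.perf_counter()
        for _ in range(iterations_per_run):
            train_step()      # placeholder: fwd + bwd + optimizer step
        sync_device()         # placeholder: wait for queued device work
        per_run_times.append(time.perf_counter() - start)
    # Amortized time per iteration over the best run.
    return min(per_run_times) / iterations_per_run
```

With iterations_per_run equal to 1, every iteration pays for a sync, which corresponds to the measurement mode used for the table above.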
First of all, I'd like to describe my understanding of how random numbers are used in torchxla. When tracing modules that use random numbers, like Dropout, torchxla needs to call `XLAGraphExecutor::GetRngSeed` to grab the current seed and put it into the traced graph. The API has a side effect of updating a DeviceContext object so that the next time it is called, the next (different) seed is returned. Here is an example with the bernoulli op used by Dropout: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/tensor_methods.cpp#L782 . In this example, we can see that the current seed is grabbed as an IR node and used to create the Bernoulli IR node.

In our dynamo/torchxla bridge we reuse the traced graph. To make sure we maintain numerical correctness (for benchmarking), the bridge should have the same behavior as regular torchxla. This PR changes the `GetRngSeedData` method to do the same thing as `GetRngSeed`, except that it is guaranteed to return a DeviceData node. In the bridge, whenever we need to prepare a graph input that is identified as a seed, we should call `GetRngSeedData`.
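To make that contract concrete, here is a small, self-contained sketch that models the behavior described above. It is purely illustrative: the class and method names mirror the description (GetRngSeed, GetRngSeedData, DeviceContext, DeviceData), but the IR node classes and the seed-update rule are toy stand-ins, not the real torch_xla implementation.

```
# Toy model of the seed contract described above -- not the real torch_xla
# code. Constant/DeviceData stand in for IR node types, and the seed-update
# rule is simplified for illustration.

class Constant:
    """Stand-in for an IR node whose value is baked into the traced graph."""
    def __init__(self, value):
        self.value = value

class DeviceData:
    """Stand-in for an IR DeviceData node: a graph input whose backing
    tensor can be swapped at execution time without retracing."""
    def __init__(self, value):
        self.value = value

class DeviceContextModel:
    """Models the DeviceContext side effect: every call hands out the
    current seed and advances it for the next caller."""
    def __init__(self, base_seed=42):
        self._seed = base_seed

    def get_rng_seed(self):
        # May return an arbitrary IR node for the current seed.
        node = Constant(self._seed)
        self._seed += 1  # simplified update; the real rule differs
        return node

    def get_rng_seed_data(self):
        # Same side effect, but the result is guaranteed to be DeviceData,
        # so a bridge replaying the cached graph can identify this input
        # as the seed and feed a fresh value on each execution.
        node = DeviceData(self._seed)
        self._seed += 1  # simplified update; the real rule differs
        return node

ctx = DeviceContextModel()
first = ctx.get_rng_seed_data()
second = ctx.get_rng_seed_data()
assert isinstance(first, DeviceData)
assert first.value != second.value  # consecutive calls yield different seeds
```

The key point is that a DeviceData node is a graph input, so the bridge can recognize the seed among the inputs of the cached graph and feed a fresh value on every execution instead of replaying the seed that was baked in at trace time.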