implement collective all_to_all op #9442
base: master
Conversation
LGTM
It seems like the implementation has always been there but just not exposed? Any thoughts on why that might be?
      return [t.cpu() for t in output_tensors]

  def test_all_to_all(self):
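For readers without the full diff, here is a minimal standalone sketch of the semantics a test like test_all_to_all checks. It uses the Gloo backend so it runs on CPU (recent PyTorch releases support all_to_all on Gloo); the worker helper, world size, and tensor values are illustrative assumptions, not the PR's actual test code:

    import os
    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _worker(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        # Rank r sends its i-th input tensor to rank i and receives
        # rank i's r-th input tensor in return.
        inputs = [torch.tensor([rank * world_size + i]) for i in range(world_size)]
        outputs = [torch.empty(1, dtype=torch.int64) for _ in range(world_size)]
        dist.all_to_all(outputs, inputs)

        expected = [torch.tensor([i * world_size + rank]) for i in range(world_size)]
        for got, want in zip(outputs, expected):
            assert torch.equal(got, want), f"rank {rank}: got {got}, want {want}"
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(_worker, args=(2,), nprocs=2)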
Do you think a performance test might be necessary here, to ensure there's no unforeseen bottleneck creating latency/throughput issues?
Yes. Stage 1 of this project is to improve op coverage. Stage 2 is to rigorously benchmark how the collective ops scale and identify any bottlenecks. I've been reassigned to work on the new repo so I won't be doing stage 2, at least in the near future.
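To make the stage-2 idea concrete, a hedged sketch of what such a micro-benchmark might look like; the message sizes, iteration counts, and Gloo backend here are placeholder assumptions (a real benchmark would target the xla backend on multiple TPUs):

    import time
    import torch
    import torch.distributed as dist

    def bench_all_to_all_single(size_bytes, iters=50, warmup=5):
        # Pick a float32 element count divisible by the world size so the
        # default equal-split behavior of all_to_all_single applies.
        world = dist.get_world_size()
        elems = max(1, size_bytes // 4 // world) * world
        inp = torch.randn(elems)
        out = torch.empty_like(inp)
        for _ in range(warmup):
            dist.all_to_all_single(out, inp)
        dist.barrier()
        start = time.perf_counter()
        for _ in range(iters):
            dist.all_to_all_single(out, inp)
        dist.barrier()
        return (time.perf_counter() - start) / iters

    if __name__ == "__main__":
        # Launch with: torchrun --nproc_per_node=2 bench_all_to_all.py
        dist.init_process_group("gloo")
        for size in (2**10, 2**16, 2**20):
            latency = bench_all_to_all_single(size)
            if dist.get_rank() == 0:
                print(f"{size:>8} B: {latency * 1e3:.3f} ms/iter")
        dist.destroy_process_group()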
@@ -359,6 +359,36 @@ def test_all_to_all_single(self, use_dynamo):
                        expected.sort().values),
                    f"Got {val}, expected {expected}")

  @staticmethod
Are tests in pjrt/ designed to run on some non-trivial distributed setup?
In some cases, yes. The tests in this file are run by tpu/run_tests.sh and expect multiple TPUs. Some of the other files in pjrt/ are part of the basic test suite (example).
How do you mean? There are two torch.distributed functions: all_to_all and all_to_all_single.
#9315
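For reference, the two entry points differ in how the exchanged data is laid out. A small sketch of both call shapes, run under torchrun with the Gloo backend; the sizes and values are illustrative:

    import torch
    import torch.distributed as dist

    # Launch with: torchrun --nproc_per_node=2 demo.py
    dist.init_process_group("gloo")
    world = dist.get_world_size()

    # all_to_all: a list with one tensor per peer, exchanged list-to-list.
    inputs = [torch.full((2,), float(peer)) for peer in range(world)]
    outputs = [torch.empty(2) for _ in range(world)]
    dist.all_to_all(outputs, inputs)

    # all_to_all_single: one contiguous tensor, implicitly split into
    # `world` equal chunks (or explicitly via input/output_split_sizes).
    inp = torch.arange(world * 2, dtype=torch.float32)
    out = torch.empty_like(inp)
    dist.all_to_all_single(out, inp)

    print(dist.get_rank(), [t.tolist() for t in outputs], out.tolist())
    dist.destroy_process_group()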