Implement collective gather op #9435
base: master

Conversation
# Module-level imports assumed from the surrounding test file:
# numpy as np, torch, torch.distributed as dist, torch_xla,
# torch_xla.runtime as xr, and the pjrt multiprocess test runner.

@staticmethod
def _scatter():
  dist.init_process_group("xla", init_method='xla://')
  device = torch_xla.device()
  world_size = xr.world_size()
  tensors = None
  if xr.global_ordinal() == 0:
    tensors = [
        torch.tensor([i], device=device, dtype=torch.float)
        for i in range(world_size)
    ]

  output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
  dist.scatter(output_tensor, tensors, src=0)
  return output_tensor.cpu()

def test_scatter(self):
  """self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
  on device 0, then scatters it. Device i should therefore receive [i]."""
  results = pjrt.run_multiprocess(self._scatter)
  for ordinal, value in results.items():
    np.testing.assert_array_equal(value, [ordinal])
Just moving this test into the appropriate class
(The same _scatter / test_scatter hunk as above, shown a second time in the diff.)
copied from above
Failing tests are expected until the TPU CI cluster is updated to use Python 3.12. See #9434
torch_xla/distributed/xla_backend.py (outdated)
    input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)
# Syncing is required to keep the heterogeneous copying below at the
NIT: Add space between code line and comment.
rank = xr.global_ordinal()

for i, input_tensor in enumerate(input_tensor_list):
  is_scalar = input_tensor.dim() == 0
This is happening during each iteration of the loop. If input_tensor_list is not empty, could we not do something like is_scalar = input_tensor_list[0].dim() == 0 instead?
It's not guaranteed that every element of input_tensor_list
has the same size. They're basically independent gather operations.
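To illustrate the point, here is a minimal standalone sketch with made-up tensors (not the backend code itself): a single gather call may carry inputs of different shapes, so the scalar check has to stay inside the loop.

import torch

# Hypothetical inputs: one gather call carrying tensors of different
# shapes, e.g. a 0-dim scalar next to a 1-dim vector.
input_tensor_list = [torch.tensor(0.5), torch.zeros(4)]

for i, input_tensor in enumerate(input_tensor_list):
  # Must be evaluated per element: True for the scalar, False for the
  # vector, so hoisting it out of the loop would be wrong here.
  is_scalar = input_tensor.dim() == 0
  print(i, is_scalar)  # prints: 0 True, then 1 False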
if rank == opts.rootRank:
  return _ret_work(output_tensors_list)
else:
  return _ret_work([[]])
What is going on here? From a plain reading, it is that if rank != opts.rootRank, we return an empty result. If that is the case, could we add something like:

if rank != opts.rootRank:
  return _ret_work([[]])

at the beginning of the function, and avoid this if split, as well as the one above?
Good point -- that would make the code simpler
Actually no -- all non-dst ranks still need to call the all_gather with their input and then sync. It's only the copying to the output that is device-specific, which means we can't return at the beginning of the function.
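A rough sketch of the flow being described, with assumptions flagged: the method shape, the torch_xla.sync() call, and the copy loop are illustrative guesses (scalar handling and error checks omitted), while _ret_work, self._mesh, and opts.rootRank come from the diff excerpts above.

import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

def gather(self, output_tensors_list, input_tensor_list, opts):
  rank = xr.global_ordinal()

  # Every rank, root or not, must feed its input into the collective.
  gathered = [
      xm.all_gather(t, dim=0, groups=self._mesh, pin_layout=False)
      for t in input_tensor_list
  ]

  # Every rank must also sync so the collective is actually executed;
  # this is why an early return at the top would desynchronize the
  # non-root ranks.
  torch_xla.sync()

  if rank != opts.rootRank:
    # Non-root ranks took part in the collective but expose no output.
    return _ret_work([[]])

  # Only the root copies the gathered results into the caller's buffers.
  for out_list, g in zip(output_tensors_list, gathered):
    for dst, chunk in zip(out_list, torch.chunk(g, len(out_list))):
      dst.copy_(chunk)
  return _ret_work(output_tensors_list)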
#9315