Skip to content

Implement collective gather op #9435

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Conversation

bfolie
Copy link
Collaborator

@bfolie bfolie commented Jul 1, 2025

@bfolie bfolie requested a review from pgmoka July 1, 2025 23:08
Comment on lines -106 to -128
@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just moving this test into the appropriate class

Comment on lines +339 to +360
@staticmethod
def _scatter():
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
world_size = xr.world_size()
tensors = None
if xr.global_ordinal() == 0:
tensors = [
torch.tensor([i], device=device, dtype=torch.float)
for i in range(world_size)
]

output_tensor = torch.tensor([-1], dtype=torch.float, device=device)
dist.scatter(output_tensor, tensors, src=0)
return output_tensor.cpu()

def test_scatter(self):
"""self._scatter instantiates a list of tensors [[0], [1], ..., [n-1]]
on device 0, then scatters it. Device i should therefore receive [i]."""
results = pjrt.run_multiprocess(self._scatter)
for ordinal, value in results.items():
np.testing.assert_array_equal(value, [ordinal])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copied from above

@bfolie
Copy link
Collaborator Author

bfolie commented Jul 2, 2025

Failing tests are expected until the TPU CI cluster is updated to use python 3.12. See #9434

@bfolie bfolie requested a review from benawilson July 2, 2025 19:31
Comment on lines 269 to 270
input_for_all_gather, dim=0, groups=self._mesh, pin_layout=False)
# Syncing is required to keep the heterogeneous copying below at the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Add space between code line and comment.

rank = xr.global_ordinal()

for i, input_tensor in enumerate(input_tensor_list):
is_scalar = input_tensor.dim() == 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is happening during each iteration of the loop.

If input_tensor_list is not empty, could we not do something like is_scalar = input_tensor_list[0].dim() == 0?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not guaranteed that every element of input_tensor_list has the same size. They're basically independent gather operations.

if rank == opts.rootRank:
return _ret_work(output_tensors_list)
else:
return _ret_work([[]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is going on here? From this base reading, it is that if rank != opts.rootRank, return an empty.

If that is the case, could we add something like:

if rank != opts.rootRank:
      return _ret_work([[]])

In the beginning of the function, and avoid this if split, as well as the one above?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point -- that would make the code simpler

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually no -- all non-dst ranks still need to call the all_gather with their input and then sync. It's only the copying to the output which is device-specific. Which means we can't return at the beginning of the function

@bfolie bfolie enabled auto-merge (squash) July 9, 2025 20:06
@bfolie bfolie disabled auto-merge July 9, 2025 20:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants