We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 202e6c1 commit 764eee4Copy full SHA for 764eee4
src/accelerate/utils/operations.py
@@ -316,6 +316,10 @@ def _gpu_gather(tensor):
316
state = PartialState()
317
gather_op = torch.distributed.all_gather_into_tensor
318
319
+ # FIXME: the 2 lines below work around a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
320
+ if state.device.type == "xpu":
321
+ torch.xpu.synchronize()
322
+
323
def _gpu_gather_one(tensor):
324
if tensor.ndim == 0:
325
tensor = tensor.clone()[None]
0 commit comments