Commit 764eee4

add xpu synchronize (#3563)
1 parent 202e6c1 commit 764eee4

1 file changed: +4 -0 lines changed


src/accelerate/utils/operations.py

Lines changed: 4 additions & 0 deletions
@@ -316,6 +316,10 @@ def _gpu_gather(tensor):
     state = PartialState()
     gather_op = torch.distributed.all_gather_into_tensor
 
+    # FIXME: the below 2 lines are added to work around a bug related to INT64 collectives in oneCCL. Remove them once pytorch-2.9 is released.
+    if state.device.type == "xpu":
+        torch.xpu.synchronize()
+
     def _gpu_gather_one(tensor):
         if tensor.ndim == 0:
             tensor = tensor.clone()[None]
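
The FIXME above ties the workaround to a PyTorch release, so one way to picture its eventual removal is a version gate. The following is only a sketch and not part of this commit: the helper name `needs_xpu_sync` and its `fixed_in` parameter are hypothetical, and it assumes the oneCCL issue is resolved from PyTorch 2.9 onward, which the commit implies but does not confirm.

```python
import re


def needs_xpu_sync(device_type: str, torch_version: str, fixed_in=(2, 9)) -> bool:
    """Hypothetical helper: report whether the pre-gather synchronize is still wanted.

    Assumes (not confirmed by the commit) that releases at or above `fixed_in`
    no longer need the workaround.
    """
    # The workaround only concerns XPU devices.
    if device_type != "xpu":
        return False
    # Read the leading "major.minor" of strings such as "2.8.0+xpu".
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        # Unrecognized version format: keep the workaround to stay on the safe side.
        return True
    major_minor = (int(match.group(1)), int(match.group(2)))
    return major_minor < fixed_in
```

With such a helper, the guard in `_gpu_gather` could read `if needs_xpu_sync(state.device.type, torch.__version__): torch.xpu.synchronize()`. The commit itself keeps the simpler unconditional check on the device type.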
