diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
index db45159adfc9..2b1715cf2377 100644
--- a/src/diffusers/models/_modeling_parallel.py
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -22,6 +22,7 @@
 import torch.distributed as dist
 
 from ..utils import get_logger
+from ..utils.torch_utils import get_device
 
 
 if TYPE_CHECKING:
@@ -290,7 +291,14 @@ def gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
     # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
     comm_backends = str(dist.get_backend(group=group))
     # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
-    gather_device = "cpu" if "cpu" in comm_backends else torch.accelerator.current_accelerator()
+    if "cpu" in comm_backends:
+        gather_device = "cpu"
+    else:
+        gather_device = torch.device(get_device())
+
+        if gather_device.type == "cpu":
+            raise RuntimeError("No suitable accelerator found.")
+
     gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
     dist.all_gather(
         gathered_sizes,
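
For context, here is a minimal standalone sketch of the gather path this patch touches. The function name `gather_sizes_sketch` is hypothetical, and importing `get_device` from `diffusers.utils.torch_utils` outside the package simply mirrors the helper the patch starts using; this is not the library's public API. The idea, per the HACK/NOTE comments in the hunk, is that when the process group is initialized with a CPU-capable backend (e.g. `cpu:gloo,cuda:nccl`), the 1-element size tensors can stay on the host and the `all_gather` avoids host-to-device and device-to-host copies.

```python
# Illustrative sketch only, assuming `get_device` returns a device string
# such as "cuda", "xpu", or "cpu" (as the patch's usage suggests).
from typing import List, Optional

import torch
import torch.distributed as dist

from diffusers.utils.torch_utils import get_device


def gather_sizes_sketch(size: int, group: Optional[dist.ProcessGroup] = None) -> List[int]:
    world_size = dist.get_world_size(group=group)
    comm_backends = str(dist.get_backend(group=group))

    if "cpu" in comm_backends:
        # A CPU-capable backend (e.g. Gloo) lets the size tensors stay on the host.
        gather_device = torch.device("cpu")
    else:
        # Otherwise fall back to the detected accelerator, as the patch does.
        gather_device = torch.device(get_device())
        if gather_device.type == "cpu":
            raise RuntimeError("No suitable accelerator found.")

    local_size = torch.tensor([size], device=gather_device, dtype=torch.int64)
    gathered_sizes = [torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)]
    dist.all_gather(gathered_sizes, local_size, group=group)
    return [int(t.item()) for t in gathered_sizes]
```

With a mixed-backend setup such as `dist.init_process_group(backend="cpu:gloo,cuda:nccl")`, the branch above takes the CPU path, so the size exchange runs over Gloo while tensor collectives elsewhere can still use NCCL.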