Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3598,7 +3598,8 @@ def from_pretrained(
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
# Get device with index assuming equal number of devices per host
tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
index = None if device_type == "cpu" else torch.distributed.get_rank() % device_module.device_count()
tp_device = torch.device(device_type, index)
# This is the easiest way to dispatch to the current process device
device_map = tp_device

Expand Down