Merged
Changes from 14 commits
23 changes: 16 additions & 7 deletions src/transformers/modeling_utils.py
@@ -769,7 +769,8 @@ def _load_state_dict_into_meta_model(
     """
     tensor_device = "cpu"
     if device_map is not None and device_map.get("", None) is not None:
-        tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
+        if device_map[""] not in ("cpu", torch.device("cpu")):
+            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
     if device_map is not None:
         device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
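A quick note on why the added guard matters: for CPU placements `torch.device("cpu").index` is `None`, so the old one-liner would turn `tensor_device` into `None` instead of leaving it as `"cpu"`. A minimal sketch of the new resolution logic, using a hypothetical `device_map` (not taken from this PR):

```python
import torch

# Hypothetical device_map putting the whole model on CPU (illustration only).
device_map = {"": torch.device("cpu")}

tensor_device = "cpu"
if device_map is not None and device_map.get("", None) is not None:
    # New guard: keep "cpu" for CPU placements, whether given as a string or a torch.device.
    if device_map[""] not in ("cpu", torch.device("cpu")):
        tensor_device = (
            device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
        )

print(tensor_device)                  # "cpu"
print(torch.device("cpu").index)      # None -- what the old code would have produced here
```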

@@ -801,14 +802,15 @@ def _load_state_dict_into_meta_model(
             )

         if device_mesh is not None:  # In this case, the param is already on the correct device!
+            rank = tensor_device if isinstance(tensor_device, int) else torch.distributed.get_rank()
Collaborator (inline review comment on the `rank` line above): indeed
             shard_and_distribute_module(
                 model,
                 param,
                 empty_param,
                 param_name,
                 casting_dtype,
                 to_contiguous,
-                tensor_device,  # the rank
+                rank,  # the rank
                 device_mesh,
             )
         else:
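On the CPU path `tensor_device` stays the string `"cpu"` rather than an integer device index, so the value passed to `shard_and_distribute_module` as the rank now falls back to the process rank. A tiny sketch of that fallback, assuming `torch.distributed` is already initialized:

```python
import torch.distributed as dist

# Sketch only: assumes dist.init_process_group(...) has already been called.
tensor_device = "cpu"  # CPU tensor-parallel case; on CUDA this is an int device index
rank = tensor_device if isinstance(tensor_device, int) else dist.get_rank()
# On GPU the device index is reused as the rank; on CPU the process group is queried instead.
```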
@@ -4095,24 +4097,31 @@ def from_pretrained(
         if tp_plan is not None:
             if not is_torch_greater_or_equal("2.5"):
                 raise EnvironmentError("tensor parallel is only supported for `torch>=2.5`.")

+            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
+            device_type = torch._C._get_accelerator().type
+
             if not torch.distributed.is_initialized():
                 try:
                     logger.warning("Tensor Parallel requires torch.distributed to be initialized first.")
                     rank = int(os.environ["RANK"])
                     world_size = int(os.environ["WORLD_SIZE"])
-                    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
-                    torch.cuda.set_device(rank)
+                    if device_type == "cuda":
+                        torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)
+                        torch.cuda.set_device(rank)
+                    elif device_type == "cpu":
+                        cpu_backend = "ccl" if int(os.environ.get("CCL_WORKER_COUNT", 0)) else "gloo"
+                        torch.distributed.init_process_group(cpu_backend, rank=rank, world_size=world_size)
                 except Exception as e:
                     raise EnvironmentError(
                         "We tried to initialize torch.distributed for you, but it failed, make"
                         "sure you init torch distributed in your script to use `tp_plan='auto'`"
                     ) from e

-            # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
-            device_type = torch._C._get_accelerator().type
             device_module = torch.get_device_module(device_type)
             # Get device with index assuming equal number of devices per host
-            tp_device = torch.device(device_type, torch.distributed.get_rank() % device_module.device_count())
+            index = None if device_type == "cpu" else torch.distributed.get_rank() % device_module.device_count()
+            tp_device = torch.device(device_type, index)
             # This is the easiest way to dispatch to the current process device
             device_map = tp_device

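For context on how this initialization path is typically exercised: `tp_plan="auto"` is meant to run under a distributed launcher that exports `RANK` and `WORLD_SIZE`, which the code above reads when `torch.distributed` has not been initialized yet (NCCL on CUDA, oneCCL or Gloo on CPU). A hedged usage sketch with a placeholder checkpoint name:

```python
# Launch with e.g.:  torchrun --nproc_per_node=4 run_tp.py
# torchrun exports RANK and WORLD_SIZE, which from_pretrained uses above to
# initialize torch.distributed when the script has not done so itself.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder model id, not taken from this PR
    tp_plan="auto",             # enables the tensor-parallel path patched above
)
```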