20 changes: 17 additions & 3 deletions src/transformers/modeling_utils.py
@@ -154,9 +154,6 @@
from safetensors.torch import save_file as safe_save_file


if is_deepspeed_available():
import deepspeed

if is_kernels_available():
from kernels import get_kernel

@@ -2007,6 +2004,8 @@ def _from_config(cls, config, **kwargs):
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
# this immediately partitions the model across all gpus, to avoid the overhead in time
# and memory copying it on CPU or each GPU first
import deepspeed

init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
with ContextManagers(init_contexts):
model = cls(config, **kwargs)
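
The comment in the hunk above explains what zero.Init() buys here: parameters are partitioned across ranks while the model is being constructed, instead of being materialized in full on CPU or on each GPU first. A minimal standalone sketch of that pattern (not the transformers code; `ds_config` is a hypothetical minimal ZeRO-3 config and a working DeepSpeed/distributed setup is assumed):

```python
# Sketch only (not the transformers implementation). Constructing a module inside
# deepspeed.zero.Init() creates its parameters already partitioned across the
# data-parallel ranks, so the full weights are never materialized on CPU or on a
# single GPU. Assumes torch.distributed / DeepSpeed is initialized and that
# `ds_config` (hypothetical) matches the ZeRO stage-3 config used for training.
import deepspeed
import torch.nn as nn

ds_config = {"zero_optimization": {"stage": 3}, "train_micro_batch_size_per_gpu": 1}

with deepspeed.zero.Init(config_dict_or_path=ds_config):
    model = nn.Linear(4096, 4096)  # weight and bias are sharded at creation time
```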
@@ -2702,6 +2701,8 @@ def resize_token_embeddings(
# Since we are basically reusing the same old embeddings with new weight values, gathering is required
is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
if is_deepspeed_zero3_enabled() and not is_quantized:
Member Author:

Here I am assuming that we do not need an additional is_deepspeed_available() check, but let me know if we do!

Member:

Yeah, we don't need one! If is_deepspeed_zero3_enabled returns True, deepspeed will be installed.

import deepspeed

with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
vocab_size = model_embeds.weight.shape[0]
else:
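
A minimal sketch of the guard discussed in the thread above (not the library code; `model_embeds` stands in for a hypothetical embedding module): the lazy `import deepspeed` sits behind `is_deepspeed_zero3_enabled()`, which only returns True when a DeepSpeed ZeRO-3 config is active, so deepspeed is necessarily importable at that point.

```python
# Sketch only (not the library code). The lazy `import deepspeed` sits behind
# is_deepspeed_zero3_enabled(), which only returns True when a DeepSpeed ZeRO-3
# config has been set up, i.e. when deepspeed is installed, so no extra
# is_deepspeed_available() check is needed. `model_embeds` is a hypothetical
# nn.Embedding-like module whose weight may be ZeRO-3 partitioned.
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled


def read_vocab_size(model_embeds):
    if is_deepspeed_zero3_enabled():
        import deepspeed  # safe: ZeRO-3 being enabled implies deepspeed is installed

        # The weight is partitioned across ranks; gather it read-only
        # (modifier_rank=None means no rank modifies it) to see the full shape.
        with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
            return model_embeds.weight.shape[0]
    return model_embeds.weight.shape[0]
```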
@@ -2732,6 +2733,8 @@ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean
# Update new_num_tokens with the actual size of new_embeddings
if pad_to_multiple_of is not None:
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
new_num_tokens = new_embeddings.weight.shape[0]
else:
@@ -2799,6 +2802,9 @@ def _get_resized_embeddings(
`new_num_tokens` is `None`
"""

if is_deepspeed_available():
import deepspeed

if pad_to_multiple_of is not None:
if not isinstance(pad_to_multiple_of, int):
raise ValueError(
@@ -2941,6 +2947,10 @@ def _get_resized_lm_head(
`torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
`None`
"""

if is_deepspeed_available():
import deepspeed

if new_num_tokens is None:
return old_lm_head

@@ -3762,6 +3772,8 @@ def float(self, *args):
@classmethod
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
if is_deepspeed_zero3_enabled():
import deepspeed

init_contexts = [no_init_weights()]
# We cannot initialize the model on meta device with deepspeed when not quantized
if not is_quantized and not _is_ds_init_called:
@@ -5345,6 +5357,8 @@ def _initialize_missing_keys(
not_initialized_submodules = dict(self.named_modules())
# This will only initialize submodules that are not marked as initialized by the line above.
if is_deepspeed_zero3_enabled() and not is_quantized:
import deepspeed

not_initialized_parameters = list(
set(
itertools.chain.from_iterable(