
Commit e0dfd7b

muellerzr, SunMarc, ArthurZucker, and amyeroberts authored
Speedup model init on CPU (by 10x+ for llama-3-8B as one example) (#31771)
* 1,100%!
* Clean
* Don't touch DS
* Experiment with dtype allocation
* skip test_load_save_without_tied_weights test
* A little faster
* Include proper upscaling?
* Fixup tests
* Potentially skip?
* Let's see if this fixes git history
* Maintain new dtype
* Fin
* Rm hook idea for now
* New approach, see what breaks
* stage
* Clean
* Stash
* Should be fin now, just need to mark failing models
* Clean up
* Simplify
* Deal with weird models
* Enc/Dec
* Skip w/ reason
* Adjust test
* Fix test
* one more test
* Keep experimenting
* Fix ref
* TO REMOVE: testing feedback CI
* Right push
* Update tests/utils/test_modeling_utils.py

  Co-authored-by: Arthur <[email protected]>

* disable
* Add new func
* Test nits from Amy
* Update src/transformers/modeling_utils.py

  Co-authored-by: amyeroberts <[email protected]>

* Adjust comment
* Adjust comment on skip
* make private
* Fin
* Should be a not flag
* Clarify and rename test

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Arthur <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
1 parent 03a3bec commit e0dfd7b

File tree

17 files changed: +181 additions, -21 deletions


docs/source/en/main_classes/model.md

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ for text generation, [`~generation.GenerationMixin`] (for the PyTorch models),
     - push_to_hub
     - all
 
+Custom models should also include a `_supports_param_buffer_assignment` attribute, which determines whether the
+faster assignment-based initialization can be applied to the model. A sign that your model needs this is if
+`test_save_and_load_from_pretrained` fails; if so, set the attribute to `False`.
+
 ## ModuleUtilsMixin
 
 [[autodoc]] modeling_utils.ModuleUtilsMixin
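For illustration only (not part of this commit), a minimal sketch of how a custom model might opt out of assignment-based initialization; `MyConfig` and `MyModel` are made-up names:

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyModel(PreTrainedModel):
    config_class = MyConfig
    # Opt out of assignment-based loading, e.g. if save/load round-trip tests start failing.
    _supports_param_buffer_assignment = False

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)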

src/transformers/modeling_utils.py

Lines changed: 58 additions & 6 deletions
@@ -338,6 +338,32 @@ def dtype_byte_size(dtype):
     return bit_size // 8
 
 
+def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
+    """
+    Checks if `model_to_load` supports param buffer assignment (such
+    as when loading in empty weights) by first checking
+    if the model explicitly disables it, then by ensuring that the state dict keys
+    are a subset of the model's parameters.
+    """
+    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
+        return False
+
+    # Some models explicitly do not support param buffer assignment
+    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
+        logger.debug(
+            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
+        )
+        return False
+
+    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
+    first_key = list(model_to_load.state_dict().keys())[0]
+    if start_prefix + first_key in state_dict:
+        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
+
+    # For cases when the `state_dict` doesn't contain real weights to the model (`test_model_weights_reload_no_missing_tied_weights`)
+    return False
+
+
 def shard_checkpoint(
     state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME
 ):
@@ -657,7 +683,7 @@ def _find_identical(tensors: List[Set[str]], state_dict: Dict[str, torch.Tensor]
     return shared_tensors, identical
 
 
-def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
+def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, assign_to_params_buffers=False):
     # Convert old format to new format if needed from a PyTorch state_dict
     old_keys = []
     new_keys = []
@@ -685,8 +711,10 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
 
     # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
     # so we need to apply the function recursively.
-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", assign_to_params_buffers=False):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+        local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
+
         args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
@@ -710,9 +738,9 @@ def load(module: nn.Module, state_dict, prefix=""):
 
         for name, child in module._modules.items():
             if child is not None:
-                load(child, state_dict, prefix + name + ".")
+                load(child, state_dict, prefix + name + ".", assign_to_params_buffers)
 
-    load(model_to_load, state_dict, prefix=start_prefix)
+    load(model_to_load, state_dict, prefix=start_prefix, assign_to_params_buffers=assign_to_params_buffers)
     # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so
     # it's safe to delete it.
     del state_dict
@@ -2852,6 +2880,10 @@ def from_pretrained(
         The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
         weights are discarded.
 
+        If the model weights are in the same precision as the base model (and the model architecture is supported),
+        weights are lazily loaded in via the `meta` device and only brought into memory once an input is passed
+        through the corresponding layer, regardless of `low_cpu_mem_usage`.
+
         Parameters:
             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                 Can be either:
@@ -2952,7 +2984,13 @@ def from_pretrained(
 
             low_cpu_mem_usage(`bool`, *optional*):
                 Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Generally should be combined with a `device_map` (such as `"auto"`) for best results.
                 This is an experimental feature and a subject to change at any moment.
+                <Tip>
+                    If the model weights are in the same precision as the model loaded in, `low_cpu_mem_usage` (without
+                    `device_map`) is redundant and will not provide any benefit in regards to CPU memory usage. However,
+                    this should still be enabled if you are passing in a `device_map`.
+                </Tip>
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model under a specific `dtype`. The different options
                 are:
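As a usage illustration for the `low_cpu_mem_usage` guidance above (a hedged sketch, not prescribed by this commit; the checkpoint name and dtype are illustrative choices):

import torch
from transformers import AutoModelForCausalLM

# With an explicit device_map, low_cpu_mem_usage is still worth enabling; without one,
# it is redundant when the checkpoint dtype already matches the model's dtype.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="auto",
)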
@@ -4018,6 +4056,7 @@ def _fix_key(key):
 
         missing_keys = sorted(set(expected_keys) - set(loaded_keys))
         unexpected_keys = set(loaded_keys) - set(expected_keys)
+
         # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model
         # buffers
         model_buffers = {n for n, _ in model.named_buffers()}
@@ -4252,7 +4291,12 @@ def _find_mismatched_keys(
                 )
             else:
                 # Sharded checkpoint or whole but low_cpu_mem_usage==True
-                error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                assign_to_params_buffers = check_support_param_buffer_assignment(
+                    model_to_load, state_dict, start_prefix
+                )
+                error_msgs = _load_state_dict_into_model(
+                    model_to_load, state_dict, start_prefix, assign_to_params_buffers
+                )
 
         else:
             # This should always be a list but, just to be sure.
@@ -4280,6 +4324,7 @@ def _find_mismatched_keys(
 
             if len(resolved_archive_file) > 1:
                 resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards")
+            assign_to_params_buffers = None
             for shard_file in resolved_archive_file:
                 # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                 if shard_file in disk_only_shard_files:
@@ -4323,7 +4368,14 @@ def _find_mismatched_keys(
                     )
                     error_msgs += new_error_msgs
                 else:
-                    error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
+                    # Sharded checkpoint or whole but low_cpu_mem_usage==True
+                    if assign_to_params_buffers is None:
+                        assign_to_params_buffers = check_support_param_buffer_assignment(
+                            model_to_load, state_dict, start_prefix
+                        )
+                    error_msgs += _load_state_dict_into_model(
+                        model_to_load, state_dict, start_prefix, assign_to_params_buffers
+                    )
 
                 # force memory release
                 del state_dict
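To make the mechanism concrete, here is a minimal standalone sketch (plain PyTorch, assuming version 2.1 or newer for `load_state_dict(assign=True)`; this is not the transformers code path) of why assigning checkpoint tensors is faster than copying them:

import torch
import torch.nn as nn

# Build the module on the meta device: no real memory is allocated for its parameters.
with torch.device("meta"):
    module = nn.Linear(4096, 4096)  # stand-in for a much larger model

# A checkpoint-like state dict holding real CPU tensors.
checkpoint = {"weight": torch.randn(4096, 4096), "bias": torch.randn(4096)}

# assign=True swaps the checkpoint tensors in as the module's parameters
# instead of copying them element-wise into pre-allocated storage.
module.load_state_dict(checkpoint, assign=True)
print(module.weight.device)  # cpu -- the parameter is backed by the checkpoint tensor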

src/transformers/models/encoder_decoder/modeling_encoder_decoder.py

Lines changed: 1 addition & 0 deletions
@@ -178,6 +178,7 @@ class EncoderDecoderModel(PreTrainedModel):
     base_model_prefix = "encoder_decoder"
     main_input_name = "input_ids"
     supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
 
     def __init__(
         self,

src/transformers/models/lxmert/modeling_lxmert.py

Lines changed: 1 addition & 0 deletions
@@ -773,6 +773,7 @@ class LxmertPreTrainedModel(PreTrainedModel):
     config_class = LxmertConfig
     load_tf_weights = load_tf_weights_in_lxmert
     base_model_prefix = "lxmert"
+    _supports_param_buffer_assignment = False
 
     def _init_weights(self, module):
         """Initialize the weights"""

src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     base_model_prefix = "vision_encoder_decoder"
     main_input_name = "pixel_values"
     supports_gradient_checkpointing = True
+    _supports_param_buffer_assignment = False
 
     def __init__(
        self,
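The three modeling files above set `_supports_param_buffer_assignment = False`. As a hedged illustration (plain PyTorch, not taken from the commit) of the kind of problem the skip reasons below point to, assignment can silently break weight tying, since the shared parameter is replaced rather than written into:

import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = emb.weight  # tie the two parameters

state_dict = {"weight": torch.randn(10, 4)}
emb.load_state_dict(state_dict, assign=True)  # replaces emb.weight with the checkpoint tensor

print(emb.weight is lm_head.weight)  # False -- the tie is gone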

tests/models/bart/test_modeling_bart.py

Lines changed: 6 additions & 0 deletions
@@ -512,6 +512,12 @@ def test_generate_fp16(self):
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def assert_tensors_close(a, b, atol=1e-12, prefix=""):
     """If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""

tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py

Lines changed: 6 additions & 0 deletions
@@ -476,6 +476,12 @@ def test_for_change_to_full_attn(self):
 
         self.assertTrue(torch.allclose(outputs1, outputs2, atol=1e-5))
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 @require_sentencepiece

tests/models/longt5/test_modeling_longt5.py

Lines changed: 12 additions & 0 deletions
@@ -758,6 +758,12 @@ def _check_encoder_attention_for_generate(self, attentions, batch_size, config,
             [encoder_expected_shape] * len(attentions),
         )
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class LongT5TGlobalModelTest(LongT5ModelTest):
@@ -1097,6 +1103,12 @@ def test_attention_outputs(self):
             [self.model_tester.num_attention_heads, block_len, 3 * block_len],
         )
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 class LongT5EncoderOnlyTGlobalModelTest(LongT5EncoderOnlyModelTest):
     def setUp(self):

tests/models/lxmert/test_modeling_lxmert.py

Lines changed: 6 additions & 0 deletions
@@ -778,6 +778,12 @@ def test_save_load_low_cpu_mem_usage_checkpoints(self):
     def test_save_load_low_cpu_mem_usage_no_safetensors(self):
         pass
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 @require_torch
 class LxmertModelIntegrationTest(unittest.TestCase):

tests/models/m2m_100/test_modeling_m2m_100.py

Lines changed: 6 additions & 0 deletions
@@ -331,6 +331,12 @@ def test_generate_fp16(self):
         model.generate(input_ids, attention_mask=attention_mask)
         model.generate(num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
 
+    @unittest.skip(
+        reason="This architecture has tied weights by default and there is no way to remove it, check: https://github.com/huggingface/transformers/pull/31771#issuecomment-2210915245"
+    )
+    def test_load_save_without_tied_weights(self):
+        pass
+
 
 def _long_tensor(tok_lst):
     return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)
