fix torch_dtype, contiguous, and load_state_dict regression #36512
Merged
Changes from 9 commits

Commits (10):

- 9f837f5 fix regression (SunMarc)
- 1feb333 fix param (SunMarc)
- c50062c Merge branch 'main' into fix-dtype-and-contiguous-regression (SunMarc)
- a5bec40 fix load_state_dict (SunMarc)
- 56fdc9d Merge remote-tracking branch 'upstream/fix-dtype-and-contiguous-regre… (SunMarc)
- 61f4bc7 style (SunMarc)
- 2ddc6ad better fix for module (SunMarc)
- bb77861 fix tests (SunMarc)
- 5c76fbd quick fix for now (SunMarc)
- 63915e1 rm print (SunMarc)
File changes:
@@ -67,6 +67,7 @@
     translate_to_torch_parallel_style,
 )
 from .quantizers import AutoHfQuantizer, HfQuantizer
+from .quantizers.quantizers_utils import get_module_from_name
 from .safetensors_conversion import auto_conversion
 from .utils import (
     ACCELERATE_MIN_VERSION,
@@ -536,11 +537,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
     is_quantized: bool = False,
-    map_location: Optional[Union[str, torch.device]] = "meta",
+    map_location: Optional[Union[str, torch.device]] = "cpu",
     weights_only: bool = True,
 ):
     """
-    Reads a `safetensor` or a `.bin` checkpoint file into `meta` if requested.
+    Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
     """
     if checkpoint_file.endswith(".safetensors") and is_safetensors_available():
         with safe_open(checkpoint_file, framework="pt") as f:
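For context on why the default matters: a state dict loaded onto the "meta" device only carries shape and dtype metadata, so its tensors cannot later be copied into a real model, which is what broke the `.bin` loading path. A minimal sketch of the difference, assuming a recent PyTorch that accepts `map_location="meta"` (file name is illustrative):

```python
import torch

# Save a tiny .bin-style checkpoint.
torch.save({"weight": torch.randn(4, 4)}, "ckpt.bin")

# Loading on "cpu" gives tensors with real storage that can be copied into a model.
on_cpu = torch.load("ckpt.bin", map_location="cpu", weights_only=True)
print(on_cpu["weight"].device, on_cpu["weight"].is_meta)    # cpu False

# Loading on "meta" keeps only shape/dtype metadata; there is no data to materialize.
on_meta = torch.load("ckpt.bin", map_location="meta", weights_only=True)
print(on_meta["weight"].device, on_meta["weight"].is_meta)  # meta True
```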
@@ -771,6 +772,7 @@ def _load_state_dict_into_meta_model(
     unexpected_keys=None,  # passing `unexpected` for cleanup from quantization items
     device_mesh=None,
     shard_file=None,
+    weights_only=True,
 ):
     """
     This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
@@ -800,7 +802,15 @@ def _load_state_dict_into_meta_model(
     if shard_file.endswith(".safetensors"):
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
     else:
-        bin_state_dict = load_state_dict(shard_file, map_location="cpu")
+        map_location = "cpu"
+        if (
+            device_map is not None
+            and hf_quantizer is not None
+            and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
+        ):
+            map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+        bin_state_dict = load_state_dict(shard_file, map_location=map_location, weights_only=weights_only)

     error_msgs = []
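The torchao branch mirrors the logic that used to live in `_load_pretrained_model` (removed further down in this diff): for `int4_weight_only`/`autoquant` checkpoints the shard has to be materialized directly on an accelerator, so the first non-offload entry of the `device_map` is used. A small sketch of that selection, with an illustrative device map:

```python
import torch

# Illustrative device_map; real ones come from accelerate's device-map inference.
device_map = {"model.embed_tokens": 0, "model.layers.0": 0, "lm_head": "cpu"}

# Take the first target that is an actual accelerator, skipping cpu/disk offload entries.
accelerator_targets = [d for d in device_map.values() if d not in ["cpu", "disk"]]
map_location = torch.device(accelerator_targets[0]) if accelerator_targets else "cpu"
print(map_location)  # device(type='cuda', index=0) on a CUDA machine
```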
@@ -822,23 +832,36 @@ def _load_state_dict_into_meta_model(
             if shard_file.endswith(".safetensors")
             else bin_state_dict[serialized_param_name]
         )

+        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
+        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
+        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+        old_param = model
+        splits = fixed_param_name.split(".")
+        for split in splits:
+            # We shouldn't hit the default value unless for quant methods like hqq that modifies expected_keys.
+            old_param = getattr(old_param, split, None)
+            if old_param is None:
+                break
+
+        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+            old_param = None
+
         # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
         # in int/uint/bool and not cast them.
         param_casting_dtype = None
         is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
-        if dtype is not None and empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
-            if (
-                keep_in_fp32_modules is not None
-                and keep_in_fp32_modules.search(fixed_param_name)
-                and dtype == torch.float16
-            ):
+        if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
+            if keep_in_fp32_modules is not None and keep_in_fp32_modules.search(fixed_param_name):
                 param_casting_dtype = torch.float32
-            else:
+            elif dtype is not None:
                 param_casting_dtype = dtype
+            elif old_param is not None:
+                param_casting_dtype = old_param.dtype

         if device_mesh is not None:  # In this case, the param is already on the correct device!
-            module_to_tp, param_type = find_submodule_and_param_name(model, fixed_param_name)
+            module_to_tp, param_type = get_module_from_name(model, fixed_param_name)
             current_module_plan = None
             full_tp_plan_ = "|".join(full_tp_plan.keys()).replace("*", "[0-9]+")
             if plan := re.search(full_tp_plan_, fixed_param_name):
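The rewritten dtype branch restores the old precedence and is the `torch_dtype` regression fix: `keep_in_fp32_modules` wins, then the explicitly requested `dtype`, and when no dtype was requested the parameter falls back to the dtype already present in the model. A simplified sketch of that precedence (hypothetical helper, omitting the float8_e4m3fn special case):

```python
import re
import torch

def pick_casting_dtype(param_name, checkpoint_dtype, requested_dtype, old_param, keep_in_fp32_pattern):
    """Illustrative only: mirror the precedence used when loading one parameter."""
    if not checkpoint_dtype.is_floating_point:
        return None  # int/uint/bool buffers are never cast
    if keep_in_fp32_pattern is not None and keep_in_fp32_pattern.search(param_name):
        return torch.float32
    if requested_dtype is not None:
        return requested_dtype
    if old_param is not None:
        return old_param.dtype  # keep whatever dtype the model was initialized with
    return None

pattern = re.compile("layer_norm")
old = torch.nn.Parameter(torch.zeros(2, dtype=torch.bfloat16))
print(pick_casting_dtype("encoder.layer_norm.weight", torch.float16, torch.float16, old, pattern))  # torch.float32
print(pick_casting_dtype("encoder.fc.weight", torch.float16, None, old, pattern))                   # torch.bfloat16
```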
@@ -859,8 +882,10 @@ def _load_state_dict_into_meta_model(
                 else:
                     param = param[rank * (row // device_mesh.size()) : (rank + 1) * (row // device_mesh.size()), :]
                     shard = Shard(0)
-                if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
+                if param_casting_dtype is not None:
                     param = param.to(param_casting_dtype)
+                if old_param.is_contiguous():
+                    param = param.contiguous()
                 local_parameter = DTensor.from_local(
                     param,
                     device_mesh=device_mesh,
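The added `contiguous()` call restores the behavior of PyTorch's regular `load_state_dict`, which copies with `param.copy_(input_param)` and therefore keeps the destination layout; with `assign=True` the checkpoint tensor is attached as-is, so a non-contiguous slice would stay non-contiguous. A small sketch of the difference:

```python
import torch

dst = torch.empty(4, 4)        # freshly created parameters are contiguous
src = torch.randn(4, 4).t()    # a transposed view is not contiguous
print(src.is_contiguous())     # False

dst.copy_(src)                 # copy_ keeps the destination's (contiguous) layout
print(dst.is_contiguous())     # True

assigned = src                 # assigning keeps the source layout, hence the explicit fix
print(assigned.is_contiguous())          # False
print(src.contiguous().is_contiguous())  # True -> what `param = param.contiguous()` guarantees
```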
@@ -873,9 +898,18 @@ def _load_state_dict_into_meta_model(
                 output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
                 distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
             else:
-                module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
+                param = param[:]
+                if old_param is not None and old_param.is_contiguous():
+                    param = param.contiguous()
+                module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)

         else:
+            param = param[:]
+            if param_casting_dtype is not None:
+                param = param.to(param_casting_dtype)
+            if old_param is not None and old_param.is_contiguous():
+                param = param.contiguous()
+
             if device_map is None:
                 param_device = "cpu"
             else:
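At this point `param` can still be a lazy slice handle returned by `safe_open(...).get_slice(...)`; hoisting `param = param[:]` materializes it as a regular tensor before the dtype cast and contiguity fix are applied. A hedged sketch of that lazy-access pattern (file and tensor names are illustrative):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"w": torch.randn(3, 3)}, "shard.safetensors")

with safe_open("shard.safetensors", framework="pt") as f:
    lazy = f.get_slice("w")   # metadata only, no tensor data read yet
    print(lazy.get_shape())   # [3, 3]
    tensor = lazy[:]          # materializes the tensor, like `param = param[:]`
    print(tensor.dtype, tensor.is_contiguous())
```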
@@ -887,9 +921,9 @@ def _load_state_dict_into_meta_model(

             if param_device == "disk":
                 if not is_safetensors:
-                    offload_index = offload_weight(param[:], fixed_param_name, offload_folder, offload_index)
+                    offload_index = offload_weight(param, fixed_param_name, offload_folder, offload_index)
             elif param_device == "cpu" and state_dict_index is not None:
-                state_dict_index = offload_weight(param[:], fixed_param_name, state_dict_folder, state_dict_index)
+                state_dict_index = offload_weight(param, fixed_param_name, state_dict_folder, state_dict_index)
             elif (
                 not is_quantized
                 or (not hf_quantizer.requires_parameters_quantization)
@@ -906,23 +940,25 @@ def _load_state_dict_into_meta_model(
             ):
                 if is_fsdp_enabled():
                     param_device = "cpu" if is_local_dist_rank_0() else "meta"
-                module, param_type = find_submodule_and_param_name(model, fixed_param_name)
-                if param_casting_dtype is not None and param_casting_dtype != empty_param.dtype:
-                    param = param[:].to(param_casting_dtype)
+                module, param_type = get_module_from_name(model, fixed_param_name)
+                print(model)
+                print(fixed_param_name)
+                print(param)
+                print(module)
                 module.load_state_dict(
-                    {param_type: param[:].to(param_device)},
+                    {param_type: param.to(param_device)},
                     strict=False,
                     assign=True,
                 )
             else:
                 hf_quantizer.create_quantized_param(
-                    model, param[:], fixed_param_name, param_device, state_dict, unexpected_keys
+                    model, param, fixed_param_name, param_device, state_dict, unexpected_keys
                 )
                 # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                 # and then cast it to CPU to avoid excessive memory usage on each GPU
                 # in comparison to the sharded model across GPUs.
                 if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
-                    module, param_type = find_submodule_and_param_name(model, fixed_param_name)
+                    module, param_type = get_module_from_name(model, fixed_param_name)
                     value = getattr(module, param_type)
                     param_to = "cpu"
                     if is_fsdp_enabled() and not is_local_dist_rank_0():
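`get_module_from_name` (imported at the top of the diff) replaces the old `find_submodule_and_param_name` helper: it resolves a fully qualified parameter name into the owning submodule plus the attribute name, which is what the `assign=True` load needs. A rough, illustrative re-implementation, not the library code:

```python
import torch
from torch import nn

def get_module_from_name(model, name):
    """Illustrative only: resolve 'block.linear.weight' to (block.linear, 'weight')."""
    if "." in name:
        module_path, attr = name.rsplit(".", 1)
        return model.get_submodule(module_path), attr
    return model, name

model = nn.Sequential(nn.Linear(2, 2))
module, attr = get_module_from_name(model, "0.weight")
print(type(module).__name__, attr)  # Linear weight

# The resolved pair is then enough for a targeted, assignment-based load.
module.load_state_dict({attr: torch.zeros(2, 2)}, strict=False, assign=True)
```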
@@ -4203,7 +4239,9 @@ def from_pretrained(
             elif not is_sharded:
                 torch_dtype = get_state_dict_dtype(state_dict)
             else:
-                one_state_dict = load_state_dict(resolved_archive_file[0], weights_only=weights_only)
+                one_state_dict = load_state_dict(
+                    resolved_archive_file[0], map_location="meta", weights_only=weights_only
+                )
                 torch_dtype = get_state_dict_dtype(one_state_dict)
                 del one_state_dict  # free CPU memory
             logger.info(
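Since `torch_dtype="auto"` only needs to inspect dtypes, the first shard can now be read on the `meta` device, avoiding a full CPU copy just for dtype detection. A simplified sketch (this `get_state_dict_dtype` is a stand-in, not the library function):

```python
import torch

def get_state_dict_dtype(state_dict):
    """Stand-in: return the first floating-point dtype found, else the first dtype."""
    for t in state_dict.values():
        if t.is_floating_point():
            return t.dtype
    return next(iter(state_dict.values())).dtype

torch.save({"w": torch.randn(2, 2, dtype=torch.bfloat16)}, "shard-1.bin")
sd = torch.load("shard-1.bin", map_location="meta", weights_only=True)
print(get_state_dict_dtype(sd))  # torch.bfloat16, without allocating any real storage
```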
@@ -4848,7 +4886,7 @@ def _load_pretrained_model(
         else:
             folder = None

-        model.expected_keys = expected_keys
+        model_to_load.expected_keys = expected_keys

         if device_map is not None:
             expanded_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix)
             if hf_quantizer is None:

SunMarc (Member, Author) commented on the `model_to_load.expected_keys` line: since we are calling _fix_state_dict_keys_on_load on model_to_load
@@ -4907,6 +4945,7 @@ def _load_pretrained_model(
                     unexpected_keys=unexpected_keys,
                     device_mesh=device_mesh,
                     resolved_archive_file=resolved_archive_file,
+                    weights_only=weights_only,
                 )
             else:
                 # We need to read the state dict as it is meta otherwise
@@ -4957,16 +4996,8 @@ def _load_pretrained_model(
                 # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                 if shard_file in disk_only_shard_files:
                     continue
-                map_location = None
-                if (
-                    device_map is not None
-                    and hf_quantizer is not None
-                    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-                    and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
-                ):
-                    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
                 state_dict = load_state_dict(
-                    shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+                    shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
                 )

                 # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
@@ -5006,6 +5037,7 @@ def _load_pretrained_model(
                         unexpected_keys=unexpected_keys,
                         device_mesh=device_mesh,
                         shard_file=shard_file,
+                        weights_only=weights_only,
                     )
                     error_msgs += new_error_msgs
                 else:
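`weights_only` is now threaded from `from_pretrained` down to every `torch.load` call, so `.bin` shards keep being deserialized with PyTorch's restricted unpickler. A brief sketch of what the flag changes:

```python
import torch

torch.save({"w": torch.ones(2)}, "plain.bin")
print(torch.load("plain.bin", weights_only=True)["w"])  # fine: tensors and plain containers only

class Payload:
    pass  # arbitrary pickled objects are rejected by the restricted unpickler

torch.save({"obj": Payload()}, "arbitrary.bin")
try:
    torch.load("arbitrary.bin", weights_only=True)
except Exception as err:
    print("rejected:", type(err).__name__)  # e.g. UnpicklingError
```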
@SunMarc why is `param = param[:]` needed?

edit - ok, this is for safetensors. Unfortunately safetensors `get_slice` does not play well with 0-dim tensors :( huggingface/safetensors#380
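For reference, the workaround exists because indexing a safetensors slice handle has not worked reliably for 0-dim (scalar) tensors, while `get_tensor` has. A hedged sketch of the failure mode (exact behavior depends on the safetensors version; see huggingface/safetensors#380):

```python
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"scalar": torch.tensor(1.0)}, "scalar.safetensors")

with safe_open("scalar.safetensors", framework="pt") as f:
    print(f.get_tensor("scalar"))  # works: tensor(1.)
    sliced = f.get_slice("scalar")
    try:
        print(sliced[:])           # indexing a 0-dim slice is the problematic path
    except Exception as err:
        print("get_slice failed:", type(err).__name__)
```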