Member

@SunMarc SunMarc commented Mar 3, 2025

What does this PR do ?

This PR fixes the regression introduced by #36335:

  • the model should be loaded in fp32 if no torch_dtype is passed
  • the new param should be contiguous if old_param is
  • change the load_state_dict map_location default back to "cpu", plus a couple of small fixes -> fixes CI!
  • switch from get_submodule_get_module_from_name to get_module_name, as it also handles the case where the param name refers to a direct param of the model.

This should fix the tests in peft and in the quantization CI.
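
For illustration, the dtype, contiguity, and module-lookup points above can be sketched with plain torch. This is a minimal sketch, not the code changed in this PR: `match_old_param`, `resolve_owner`, and `Tiny` are hypothetical names, and falling back to `torch.float32` is an assumption standing in for "no torch_dtype is passed".

```python
import torch
from torch import nn


def match_old_param(new_value, old_param, torch_dtype=None):
    """Hypothetical helper: prepare a checkpoint tensor to replace old_param."""
    # Assumption: with no explicit torch_dtype, fall back to fp32.
    target_dtype = torch_dtype if torch_dtype is not None else torch.float32
    value = new_value.to(target_dtype)
    # Keep the memory layout of the parameter being replaced.
    if old_param.is_contiguous() and not value.is_contiguous():
        value = value.contiguous()
    return value


def resolve_owner(model, param_name):
    """Hypothetical helper: return (owning module, leaf name) for a param name."""
    if "." in param_name:
        module_path, leaf = param_name.rsplit(".", 1)
        return model.get_submodule(module_path), leaf
    # No dot: the parameter is registered directly on the model.
    return model, param_name


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))  # direct param on the model
        self.proj = nn.Linear(4, 3)


model = Tiny()
# fp16 checkpoint tensor stored transposed, so this view is non-contiguous.
checkpoint_weight = torch.randn(4, 3).to(torch.float16).t()
new_weight = match_old_param(checkpoint_weight, model.proj.weight)

owner, leaf = resolve_owner(model, "proj.weight")
direct_owner, direct_leaf = resolve_owner(model, "scale")
```

Here `new_weight` ends up in fp32 and contiguous even though the checkpoint view was neither, and `"scale"` resolves to the model itself rather than to a submodule.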

Perf with TP on 2 GPUs:

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.34it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:03<00:00,  1.33it/s]
Model loading time: 5.91 seconds
Model loading time: 5.77 seconds
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                 model_load        63.70%        3.572s       100.00%        5.607s        5.607s       0.000us         0.00%     773.210ms     773.210ms           8 b     -14.96 Gb       7.97 Gb     -29.92 Gb             1
                                aten::copy_        14.63%     820.125ms        45.88%        2.572s       3.593ms     773.210ms       100.00%     949.505ms       1.326ms           0 b      -2.25 Gb           0 b           0 b           716
                                   aten::to         0.03%       1.862ms        29.27%        1.641s       1.124ms       0.000us         0.00%     773.210ms     529.596us         272 b           0 b       7.97 Gb           0 b          1460
                             aten::_to_copy         0.08%       4.381ms        29.23%        1.639s       2.788ms       0.000us         0.00%     773.210ms       1.315ms         272 b           0 b       7.97 Gb           0 b           588
                            cudaMemcpyAsync        14.14%     792.725ms        14.14%     792.725ms       2.715ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292
                           aten::contiguous         0.01%     321.399us        13.54%     759.069ms      11.860ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64
                                aten::clone         0.02%       1.045ms        13.53%     758.747ms      11.855ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64
                             cudaMemGetInfo         4.80%     269.310ms         4.87%     273.148ms     273.148ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1
                               aten::detach         0.14%       7.882ms         1.13%      63.286ms      19.062us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3320
                                     detach         0.86%      48.257ms         0.99%      55.334ms      19.504us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2837
                           _FromTorchTensor         0.38%      21.113ms         0.44%      24.477ms     108.785us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225
                                aten::empty         0.16%       8.994ms         0.24%      13.663ms       8.971us       0.000us         0.00%       0.000us       0.000us      47.12 Gb      47.12 Gb      29.92 Gb      29.92 Gb          1523
                      cudaStreamSynchronize         0.23%      12.651ms         0.23%      12.651ms      43.324us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292
                        aten::empty_strided         0.17%       9.539ms         0.18%       9.957ms      16.297us       0.000us         0.00%       0.000us       0.000us         272 b         272 b       7.97 Gb       7.97 Gb           611
                                 aten::set_         0.14%       7.715ms         0.14%       7.715ms       6.594us       0.000us         0.00%       0.000us       0.000us     -29.92 Gb     -29.92 Gb           0 b           0 b          1170
                                 aten::view         0.10%       5.355ms         0.10%       5.355ms       3.852us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          1390
                                 cudaMalloc         0.09%       4.963ms         0.09%       4.963ms       2.481ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2
                               aten::select         0.07%       3.705ms         0.08%       4.329ms       7.426us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           583
                                   Resource         0.07%       3.731ms         0.07%       3.731ms     932.860us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             4
                                 aten::item         0.02%       1.049ms         0.06%       3.617ms       6.214us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                              aten::view_as         0.02%       1.331ms         0.06%       3.364ms      14.951us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225
                                aten::slice         0.05%       3.006ms         0.06%       3.346ms       6.484us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           516
                  aten::_local_scalar_dense         0.05%       2.568ms         0.05%       2.568ms       4.412us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                              aten::reshape         0.02%       1.028ms         0.04%       2.218ms       3.811us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                           aten::as_strided         0.02%       1.010ms         0.02%       1.010ms       0.868us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          1163
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 5.607s
Self CUDA time total: 773.210ms

Loading took 5.909772157669067 seconds
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                 model_load        66.96%        3.596s       100.00%        5.371s        5.371s       0.000us         0.00%     769.832ms     769.832ms           8 b     -14.96 Gb       7.97 Gb     -29.92 Gb             1
                                aten::copy_        15.43%     828.993ms        48.13%        2.585s       3.611ms     769.832ms       100.00%     944.418ms       1.319ms           0 b      -2.25 Gb           0 b           0 b           716
                                   aten::to         0.04%       1.920ms        30.65%        1.646s       1.128ms       0.000us         0.00%     769.832ms     527.282us         272 b           0 b       7.97 Gb           0 b          1460
                             aten::_to_copy         0.08%       4.544ms        30.61%        1.644s       2.796ms       0.000us         0.00%     769.832ms       1.309ms         272 b           0 b       7.97 Gb           0 b           588
                            cudaMemcpyAsync        14.69%     788.871ms        14.69%     788.871ms       2.702ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292
                           aten::contiguous         0.01%     310.266us        14.30%     768.117ms      12.002ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64
                                aten::clone         0.01%     695.911us        14.29%     767.807ms      11.997ms       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64
                               aten::detach         0.15%       8.003ms         1.14%      61.366ms      18.484us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          3320
                                     detach         0.86%      46.253ms         0.99%      53.289ms      18.784us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          2837
                           _FromTorchTensor         0.39%      21.086ms         0.46%      24.558ms     109.145us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225
                                aten::empty         0.18%       9.418ms         0.26%      13.991ms       9.186us       0.000us         0.00%       0.000us       0.000us      47.12 Gb      47.12 Gb      29.92 Gb      29.92 Gb          1523
                      cudaStreamSynchronize         0.25%      13.234ms         0.25%      13.234ms      45.322us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           292
                        aten::empty_strided         0.18%       9.468ms         0.18%       9.875ms      16.162us       0.000us         0.00%       0.000us       0.000us         272 b         272 b       7.97 Gb       7.97 Gb           611
                             cudaMemGetInfo         0.16%       8.779ms         0.16%       8.779ms       8.779ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             1
                                 aten::set_         0.16%       8.561ms         0.16%       8.561ms       7.317us       0.000us         0.00%       0.000us       0.000us     -29.92 Gb     -29.92 Gb           0 b           0 b          1170
                                 aten::view         0.10%       5.547ms         0.10%       5.547ms       3.990us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          1390
                                 cudaMalloc         0.09%       4.880ms         0.09%       4.880ms       2.440ms       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b             2
                               aten::select         0.07%       3.773ms         0.08%       4.339ms       7.442us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           583
                              aten::view_as         0.02%       1.330ms         0.06%       3.472ms      15.430us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           225
                                aten::slice         0.06%       2.984ms         0.06%       3.338ms       6.470us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           516
                                 aten::item         0.02%       1.074ms         0.05%       2.607ms       4.479us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                              aten::reshape         0.02%       1.050ms         0.04%       2.227ms       3.826us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                  aten::_local_scalar_dense         0.03%       1.532ms         0.03%       1.532ms       2.633us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b           582
                           aten::as_strided         0.02%     964.845us         0.02%     964.845us       0.830us       0.000us         0.00%       0.000us       0.000us           0 b           0 b           0 b           0 b          1163
                           aten::empty_like         0.01%     290.848us         0.01%     688.124us      10.752us       0.000us         0.00%       0.000us       0.000us       2.25 Gb           0 b           0 b           0 b            64
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 5.371s
Self CUDA time total: 769.832ms

@github-actions github-actions bot marked this pull request as draft March 3, 2025 14:18
@SunMarc SunMarc marked this pull request as ready for review March 3, 2025 14:55
@SunMarc SunMarc requested a review from ArthurZucker March 3, 2025 14:55
@SunMarc SunMarc changed the title from "fix torch_dtype and contiguous regression" to "fix torch_dtype, contiguous, and load_state_dict regression" Mar 3, 2025
folder = None

model.expected_keys = expected_keys
model_to_load.expected_keys = expected_keys
Member Author

since we are calling _fix_state_dict_keys_on_load on model_to_load

Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM

@SunMarc SunMarc merged commit 0463901 into main Mar 3, 2025
24 checks passed
@SunMarc SunMarc deleted the fix-dtype-and-contiguous-regression branch March 3, 2025 17:35
garrett361 pushed a commit to garrett361/transformers that referenced this pull request Mar 4, 2025
…ace#36512)

* fix regression

* fix param

* fix load_state_dict

* style

* better fix for module

* fix tests

* quick fix for now

* rm print
distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
else:
module_to_tp.load_state_dict({param_type: param[:]}, strict=False, assign=True)
param = param[:]
Contributor

@fxmarty-amd fxmarty-amd Mar 10, 2025

@SunMarc why is param = param[:] needed?

edit - ok, this is for safetensors. Unfortunately, safetensors get_slice does not play well with 0-dim tensors :( huggingface/safetensors#380
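
An analogous restriction can be seen with plain torch tensors, which is roughly what makes a bare `[:]` awkward for scalars. The sketch below uses only torch and does not reproduce the safetensors `get_slice` behavior from the linked issue; `materialize` is a hypothetical helper name.

```python
import torch


def materialize(param):
    """Hypothetical helper: index a torch tensor fully, including scalars."""
    # `[:]` slices along the first dimension, which a 0-dim tensor lacks,
    # so scalars are indexed with Ellipsis instead.
    return param[...] if param.dim() == 0 else param[:]


scalar = torch.tensor(3.0)   # 0-dim tensor
vector = torch.arange(4.0)   # 1-dim tensor

try:
    scalar[:]
    slice_failed = False
except IndexError:
    # torch refuses to apply a slice to a 0-dim tensor.
    slice_failed = True

scalar_value = materialize(scalar)
vector_value = materialize(vector)
```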
