Fix slicing for 0-dim param #36580
Changes from 2 commits
```diff
@@ -799,9 +799,10 @@ def _load_state_dict_into_meta_model(
     for submodule in model.modules():
         full_tp_plan.update(getattr(submodule, "_tp_plan", {}))
 
+    is_safetensors_shard = shard_file.endswith(".safetensors")
     file_pointer = None
     bin_state_dict = None
-    if shard_file.endswith(".safetensors"):
+    if is_safetensors_shard:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
     else:
         map_location = "cpu"
```
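For context, the new `is_safetensors_shard` flag names the format check once instead of re-evaluating `shard_file.endswith(".safetensors")` at every use site. Below is a minimal sketch of the two loading paths the flag distinguishes; `open_shard` is a hypothetical helper used only for illustration (the real loader keeps `file_pointer` and `bin_state_dict` as local variables rather than returning them):

```python
import torch
from safetensors import safe_open

def open_shard(shard_file: str, tensor_device: str = "cpu"):
    """Sketch: open a checkpoint shard in either supported format."""
    is_safetensors_shard = shard_file.endswith(".safetensors")
    if is_safetensors_shard:
        # Lazy path: safe_open exposes the file's tensors without reading
        # them; data is only pulled in when a slice is materialized later.
        file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
        return file_pointer, None
    # Eager path: torch.load deserializes the entire .bin shard at once.
    bin_state_dict = torch.load(shard_file, map_location="cpu", weights_only=True)
    return None, bin_state_dict
```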
```diff
@@ -831,7 +832,7 @@ def _load_state_dict_into_meta_model(
             # we need to use serialized_param_name as file pointer is untouched
             param = (
                 file_pointer.get_slice(serialized_param_name)
-                if shard_file.endswith(".safetensors")
+                if is_safetensors_shard
                 else bin_state_dict[serialized_param_name]
             )
```
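Note the type asymmetry this ternary creates: `get_slice` returns a lazy slice handle that must be indexed to produce a tensor, while the `.bin` branch already yields an eager `torch.Tensor`. A sketch of the difference, assuming hypothetical shard files and a tensor named `"weight"`:

```python
import torch
from safetensors import safe_open

# Lazy safetensors path: get_slice returns a handle, not a tensor.
with safe_open("model.safetensors", framework="pt", device="cpu") as f:
    lazy = f.get_slice("weight")   # slice handle; no data read yet
    shape = lazy.get_shape()       # metadata only
    tensor = lazy[:]               # materializes the tensor from disk

# Eager .bin path: the dict lookup already yields a torch.Tensor.
bin_state_dict = torch.load("model.bin", map_location="cpu", weights_only=True)
tensor = bin_state_dict["weight"]  # no [:] needed
```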
```diff
@@ -900,13 +901,15 @@ def _load_state_dict_into_meta_model(
                 output_fn = partial(tp_layer._prepare_output_fn, tp_layer.output_layouts, tp_layer.use_local_output)
                 distribute_module(module_to_tp, device_mesh, None, input_fn, output_fn)
             else:
-                param = param[:]
+                if is_safetensors_shard:
+                    param = param[:]
                 if old_param is not None and old_param.is_contiguous():
                     param = param.contiguous()
                 module_to_tp.load_state_dict({param_type: param}, strict=False, assign=True)
 
         else:
-            param = param[:]
+            if is_safetensors_shard:
+                param = param[:]
 
             if param_casting_dtype is not None:
                 param = param.to(param_casting_dtype)
             if old_param is not None and old_param.is_contiguous():
```
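This hunk is the actual fix. Previously `param[:]` ran unconditionally, but on the `.bin` path `param` is already an eager tensor, and slicing a 0-dim (scalar) tensor raises an `IndexError` in PyTorch. A minimal reproduction:

```python
import torch

scalar = torch.tensor(1.0)  # 0-dim tensor, e.g. a scalar parameter in a .bin shard
try:
    scalar[:]               # a slice needs at least one dimension to act on
except IndexError as err:
    print(err)              # e.g. "slice() cannot be applied to a 0-dim tensor."
```

Gating the `[:]` on `is_safetensors_shard` keeps the materialization step where it is required (turning a lazy safetensors slice into a real tensor) and skips it where it can only fail (an already-loaded `.bin` tensor that may be 0-dim).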