Skip to content

Commit bc3d578

Browse files
SunMarcfxmarty-amd
andauthored
Fix slicing for 0-dim param (#36580)
* fix * switch to ellipsis instead * Add co-author Co-authored-by: fxmarty-amd <[email protected]> * Add co-author second try Co-authored-by: fxmarty-amd <[email protected]>
1 parent fbb18ce commit bc3d578

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/transformers/integrations/tensor_parallel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def shard_and_distribute_module(
531531
param, empty_param, param_type, param_casting_dtype, is_contiguous, rank, device_mesh
532532
)
533533
else:
534-
param = param[:]
534+
param = param[...]
535535
if is_contiguous:
536536
param = param.contiguous()
537537

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ def _load_state_dict_into_meta_model(
795795
device_mesh,
796796
)
797797
else:
798-
param = param[:]
798+
param = param[...]
799799
if casting_dtype is not None:
800800
param = param.to(casting_dtype)
801801
if to_contiguous:

0 commit comments

Comments
 (0)