Skip to content

Commit 3f8941e

Browse files
committed
Upcast parameters only if requires_grad
1 parent 9a20895 commit 3f8941e

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

src/accelerate/utils/fsdp_utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,13 +707,17 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
707707
# Set it to None if it doesn't exist and do the upcast always
708708
model_dtype = getattr(model, "dtype", None)
709709
if accelerator.mixed_precision != "no" and (model_dtype is None or model_dtype != torch.float32):
710-
# We upcast the model according to `deepspeed`'s implementation
710+
# We upcast the trainable parameters according to `deepspeed`'s implementation
711711
# More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section
712-
model = model.to(torch.float32)
713-
if accelerator.is_main_process:
714-
# TODO(siro1): Add a warning for each parameter that was upcasted
712+
upcasted_params = []
713+
for name, param in model.named_parameters():
714+
if param.requires_grad and param.dtype != torch.float32:
715+
upcasted_params.append(name)
716+
param.data = param.data.to(torch.float32)
717+
if accelerator.is_main_process and upcasted_params:
715718
warnings.warn(
716-
"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints."
719+
"FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints. "
720+
f"This effects {len(upcasted_params)} parameters: {upcasted_params[:10]}..."
717721
)
718722
return model
719723

0 commit comments

Comments
 (0)