
fix FSDP2 test case failure on XPU #3771

Merged
S1ro1 merged 2 commits into huggingface:main from yao-matrix:issue-509
Sep 12, 2025

@yao-matrix
Contributor

bug

Running pytest -rA tests/fsdp/test_fsdp.py::FSDP2IntegrationTest::test_checkpointing on XPU fails with:

stderr: [rank1]: main()
stderr: [rank1]: File "/opt/venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/test_checkpointing.py", line 265, in main
stderr: [rank1]: training_function(config, args)
stderr: [rank1]: File "/opt/venv/lib/python3.12/site-packages/accelerate/test_utils/scripts/external_deps/test_checkpointing.py", line 171, in training_function
stderr: [rank1]: accelerator.load_state(args.resume_from_checkpoint)
stderr: [rank1]: File "/opt/venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 3714, in load_state
stderr: [rank1]: self.scaler._lazy_init_scale_growth_tracker(self.scaler._device)
stderr: [rank1]: File "/opt/venv/lib/python3.12/site-packages/torch/amp/grad_scaler.py", line 171, in _lazy_init_scale_growth_tracker
stderr: [rank1]: assert self._growth_tracker is None, "_growth_tracker initialized before _scale"
stderr: [rank1]: ^^^^^^^^^^^^^^^^^^^^
stderr: [rank1]: AttributeError: 'GradScaler' object has no attribute '_growth_tracker'. Did you mean: '_get_growth_tracker'?

root cause

get_fsdp2_grad_scaler instantiates GradScaler without passing device, and the caller cannot supply it either, because GradScalerKwargs has no device field. PyTorch therefore defaults device to cuda, which leads to this failure on XPU.
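
The failure mode can be illustrated without an accelerator. The sketch below is a toy stand-in: `ToyScaler`, its `accelerator_available` flag, and `try_lazy_init` are hypothetical names, not PyTorch code. It assumes that a scaler which finds its requested backend unavailable disables itself and skips creating its internal state, which is what the AttributeError in the traceback suggests.

```python
class ToyScaler:
    """Toy stand-in for a grad scaler; NOT torch.amp.GradScaler."""

    def __init__(self, device="cuda", accelerator_available=False):
        self._device = device
        # Assumption: the scaler disables itself when the requested backend
        # is unavailable, and only an enabled scaler creates its state.
        self._enabled = accelerator_available
        if self._enabled:
            self._scale = None
            self._growth_tracker = None

    def lazy_init(self):
        # Like the assert in the traceback, this reads _growth_tracker first.
        tracker = self._growth_tracker
        assert tracker is None, "_growth_tracker initialized before _scale"
        self._scale = 65536.0
        self._growth_tracker = 0


def try_lazy_init(scaler):
    """Return the AttributeError message raised by lazy_init, or None."""
    try:
        scaler.lazy_init()
    except AttributeError as exc:
        return str(exc)
    return None
```

With the default device and no matching backend, `try_lazy_init(ToyScaler())` returns an error message naming `_growth_tracker`; a scaler created for an available device initializes cleanly.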

fix

Explicitly pass the device when constructing GradScaler.
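
A minimal sketch of the idea, not the code merged in this PR: the device type is injected into the scaler keyword arguments at construction time instead of relying on the default. `ScalerOptions` and `scaler_kwargs_for` are hypothetical names used only for illustration; `ScalerOptions` stands in for a user-facing options object that, like GradScalerKwargs, carries no device field.

```python
from dataclasses import asdict, dataclass


@dataclass
class ScalerOptions:
    """Hypothetical stand-in for user-facing scaler options (no device field)."""

    init_scale: float = 65536.0
    growth_interval: int = 2000
    enabled: bool = True


def scaler_kwargs_for(device_type, options=None):
    """Build scaler keyword arguments that always carry the device."""
    kwargs = asdict(options) if options is not None else {}
    # The device comes from the runtime (e.g. "xpu"), never from a default.
    kwargs["device"] = device_type
    return kwargs
```

With PyTorch installed, the result could then be forwarded as `torch.amp.GradScaler(**scaler_kwargs_for("xpu", ScalerOptions()))`, assuming a PyTorch version whose `torch.amp.GradScaler` accepts a `device` argument.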

result

The test passes.

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
@yao-matrix
Contributor Author

@SunMarc, please help review, thanks.

Signed-off-by: YAO Matrix <matrix.yao@intel.com>
if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
    raise ValueError(
        f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
    )
Contributor Author

The prior error message was misleading.

Contributor

@S1ro1 S1ro1 left a comment

Thank you! LGTM

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@S1ro1 S1ro1 merged commit 45959d7 into huggingface:main Sep 12, 2025
25 checks passed
@yao-matrix yao-matrix deleted the issue-509 branch September 15, 2025 18:11