
Fix: ensure consistent dtype and eval mode in pipeline save/load tests #13339

Open
YangKai0616 wants to merge 9 commits into huggingface:main from YangKai0616:fix-consistent

Conversation

@YangKai0616 (Contributor)

Issue:

The test_save_load_float16 test has a hidden non-determinism problem:

  1. When the pipeline is constructed directly, its components default to training mode (e.g., text_encoder contains dropout layers). However, after loading with from_pretrained, the components are in eval mode.
  2. Certain buffers (such as blur_kernel in MotionConv2d) may have inconsistent dtypes during the save/load process, leading to inconsistent inference results.

Although the test usually passes "by luck" on CUDA (the dropout-induced errors are diluted by numerical propagation), this is fundamentally a correctness issue in both the test design and the implementation.
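The train/eval divergence described above is easy to reproduce in isolation. A minimal sketch in plain PyTorch, independent of the pipeline code:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(4)

# nn.Module defaults to training mode: elements are randomly zeroed
# and survivors are scaled by 1/(1-p) = 2.0, so outputs are nondeterministic.
y_train = drop(x)

# After .eval(), dropout becomes the identity and is fully deterministic.
drop.eval()
y_eval = drop(x)
print(y_train, y_eval)
```

Since `from_pretrained` returns components in eval mode, comparing against a directly constructed (training-mode) pipeline compares a stochastic forward pass against a deterministic one.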

Fix:

  1. Set the text_encoder component in the pipeline to eval mode, consistent with the handling in the other test_save_load_* tests such as test_save_load_local.
  2. In the forward pass, automatically cast the blur_kernel to the input dtype. This ensures consistency before and after save/load operations and avoids dtype mismatches. (blur_kernel is essentially a fixed mathematical constant, not a learnable parameter. It should adapt to the dtype of the computation stream, rather than forcing the computation stream's data to adapt to it.)
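The buffer issue can be demonstrated with a minimal sketch. `BlurLike` below is a hypothetical stand-in for the real module, not the actual MotionConv2d code: because the buffer is registered with persistent=False, it never enters the state dict, so a reloaded module keeps the float32 value from __init__ even when the saved module was cast with .half():

```python
import torch
import torch.nn as nn

class BlurLike(nn.Module):
    # Hypothetical stand-in: a fixed mathematical constant stored as a
    # non-persistent buffer, mirroring the blur_kernel pattern.
    def __init__(self):
        super().__init__()
        self.register_buffer("blur_kernel", torch.tensor([1.0, 2.0, 1.0]), persistent=False)

saved = BlurLike().half()       # mimics casting the pipeline with .half()
state = saved.state_dict()      # blur_kernel is absent: persistent=False

reloaded = BlurLike()           # fresh construction re-creates it in float32
reloaded.load_state_dict(state)
print(saved.blur_kernel.dtype, reloaded.blur_kernel.dtype)
```

The two modules therefore run their forward passes at different precisions, which is exactly the before/after-save inconsistency the fix targets.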

Hi @sayakpaul , please help review, thanks!

```diff
 expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
-x = x.to(expanded_kernel.dtype)
-x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
+x = F.conv2d(x, expanded_kernel.to(x.dtype), padding=self.blur_padding, groups=self.in_channels)
```
Member
How is this relevant?

Contributor Author
blur_kernel is registered with persistent=False, so it's not saved into the checkpoint. After from_pretrained, it's always re-initialized in float32, while before saving it was float16 (due to .half()). This dtype mismatch causes the two inference runs to take different numerical paths.

We can print expanded_kernel.dtype in tests/pipelines/wan/test_wan_animate.py::WanAnimatePipelineFastTests::test_save_load_float16 — it prints float16 before saving and float32 after loading.

Member
Then I propose we xfail that test in the Wan Animate test suite and tackle it in a separate PR. We might want to use modules_to_not_convert or something similar.

Contributor Author
Done.

In addition, regarding the blur_kernel change: since it is a fixed mathematical constant, I think it is reasonable to adjust its dtype directly to match the computation stream. What do you think? If you agree, I will submit a new PR for that part and restore the xfailed test.

Collaborator
Arguably the call to module.named_parameters() below in test_save_load_float16 is too restrictive, as named_parameters() will not return buffers:

```python
for name, param in module.named_parameters():
    if any(
        module_to_keep_in_fp32 in name.split(".")
        for module_to_keep_in_fp32 in module._keep_in_fp32_modules
    ):
        param.data = param.data.to(torch_device).to(torch.float32)
    else:
        param.data = param.data.to(torch_device).to(torch.float16)
```

Should we modify the test to also cast buffers to torch.float16?
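The parameters/buffers distinction is easy to verify directly. A small sketch with a hypothetical Toy module:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):
    # Hypothetical module with one parameterized layer and one buffer.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.register_buffer("blur_kernel", torch.ones(3))

toy = Toy()
param_names = sorted(n for n, _ in toy.named_parameters())
buffer_names = sorted(n for n, _ in toy.named_buffers())
print(param_names)   # buffers are not included here
print(buffer_names)  # buffers only appear via named_buffers()
```

So a cast loop driven only by named_parameters() silently leaves every buffer at its original dtype.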

Collaborator
I agree that calling expanded_kernel.to(x.dtype) is probably better than x = x.to(expanded_kernel.dtype) in the Wan Animate code above. I also agree with #13339 (comment) that any changes to the Wan Animate code should be spun off into a separate PR, and this PR should focus on changes to test_save_load_float16.

Contributor Author
> Arguably the call to module.named_parameters() below in test_save_load_float16 is too restrictive, as named_parameters() will not return buffers:
>
> Should we modify the test to also cast buffers to torch.float16?

Hey @dg845, you are right. I modified the test code to cast all buffers to torch.float16 as well, except for the RoPE buffers, since they are not affected by from_pretrained(tmpdir, torch_dtype=torch.float16).
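A sketch of what that buffer cast could look like, on a hypothetical Dummy module whose rope_freqs buffer stands in for the real RoPE buffers (names are illustrative, not the actual test code):

```python
import torch
import torch.nn as nn

class Dummy(nn.Module):
    # Hypothetical module: one ordinary buffer plus a RoPE-style buffer
    # that should stay in float32.
    def __init__(self):
        super().__init__()
        self.register_buffer("blur_kernel", torch.ones(3))
        self.register_buffer("rope_freqs", torch.ones(4))

module = Dummy()
for name, buf in module.named_buffers():
    if "rope" in name:   # skip RoPE buffers, which the load path leaves untouched
        continue
    buf.data = buf.data.to(torch.float16)

print(module.blur_kernel.dtype, module.rope_freqs.dtype)
```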

Contributor Author
> I agree that calling expanded_kernel.to(x.dtype) is probably better than x = x.to(expanded_kernel.dtype) in the Wan Animate code above. I also agree with #13339 (comment) that any changes to the Wan Animate code should be spun off into a separate PR, and this PR should focus on changes to test_save_load_float16.

I have already made the modification here in the separate PR #13364.

Comment on lines +1450 to +1452
```python
for key in components:
    if "text_encoder" in key and hasattr(components[key], "eval"):
        components[key].eval()
```
Member
We can safely apply eval if the component has the eval attribute.

Contributor Author
Done.

@sayakpaul sayakpaul requested a review from dg845 March 27, 2026 06:33
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dg845 (Collaborator) left a comment

Thanks for the PR!
