Fix: ensure consistent dtype and eval mode in pipeline save/load tests#13339
YangKai0616 wants to merge 9 commits into huggingface:main from
Conversation
```diff
  expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
- x = x.to(expanded_kernel.dtype)
- x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
+ x = F.conv2d(x, expanded_kernel.to(x.dtype), padding=self.blur_padding, groups=self.in_channels)
```
blur_kernel is registered with persistent=False, so it's not saved into the checkpoint. After from_pretrained, it's always re-initialized in float32, while before saving it was float16 (due to .half()). This dtype mismatch causes the two inference runs to take different numerical paths.
We can print expanded_kernel.dtype in the test tests/pipelines/wan/test_wan_animate.py::WanAnimatePipelineFastTests::test_save_load_float16 — it outputs float16 before save and float32 after load.
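The buffer behavior described above can be reproduced with a minimal standalone module (a hypothetical `Blur` class standing in for `MotionConv2d`, not the actual diffusers code):

```python
import torch
import torch.nn as nn


class Blur(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent buffer: excluded from state_dict, so it is
        # re-created in float32 whenever the module is re-instantiated.
        self.register_buffer("blur_kernel", torch.ones(3, 3), persistent=False)


m = Blur().half()
print(m.blur_kernel.dtype)              # torch.float16: .half() casts buffers too
print("blur_kernel" in m.state_dict())  # False: persistent=False skips serialization

# A fresh module loading the saved state keeps its construction dtype,
# mirroring what happens across save_pretrained/from_pretrained.
m2 = Blur()
m2.load_state_dict(m.state_dict())
print(m2.blur_kernel.dtype)             # torch.float32
```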
Then I propose we xfail that test in the Wan Animate test suite and tackle it in a separate PR. We might want to use modules_to_not_convert or something similar.
Done.
In addition, regarding the modification of blur_kernel: since it is a fixed mathematical constant, I think it is reasonable to cast its dtype directly to match the computation flow. What do you think? If you agree, I will submit a new PR to modify this part and restore the xfail in the test.
Arguably the call to module.named_parameters() below in test_save_load_float16 is too restrictive, as named_parameters() will not return buffers:
(see diffusers/tests/pipelines/test_pipelines_common.py, lines 1438 to 1445 at commit 1fe2125)
Should we modify the test to also cast buffers to torch.float16?
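The parameter/buffer distinction can be sketched with a toy module and a hypothetical cast_fp16 helper (illustrative only, not the actual test code):

```python
import torch
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4))
        self.register_buffer("scale", torch.ones(4))  # a buffer, not a parameter


m = Toy()

# named_parameters() only yields the parameter; the buffer is missed.
param_names = [n for n, _ in m.named_parameters()]
buffer_names = [n for n, _ in m.named_buffers()]
print(param_names)   # ['weight']
print(buffer_names)  # ['scale']


def cast_fp16(module):
    # Cast floating-point parameters AND buffers, unlike a
    # named_parameters()-only loop.
    for p in module.parameters():
        if p.is_floating_point():
            p.data = p.data.half()
    for b in module.buffers():
        if b.is_floating_point():
            b.data = b.data.half()
    return module


cast_fp16(m)
print(m.weight.dtype, m.scale.dtype)  # torch.float16 torch.float16
```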
I agree that calling expanded_kernel.to(x.dtype) is probably better than x = x.to(expanded_kernel.dtype) in the Wan Animate code above. I also agree with #13339 (comment) that any changes to the Wan Animate code should be spun off into a separate PR, and this PR should focus on changes to test_save_load_float16.
> Arguably the call to module.named_parameters() below in test_save_load_float16 is too restrictive, as named_parameters() will not return buffers. Should we modify the test to also cast buffers to torch.float16?
Hey @dg845, you are right. I modified the test code to cast all buffers to torch.float16 as well, except for the RoPE buffers, since they are not affected by from_pretrained(tmpdir, torch_dtype=torch.float16).
> I agree that calling expanded_kernel.to(x.dtype) is probably better than x = x.to(expanded_kernel.dtype) in the Wan Animate code above. I also agree with #13339 (comment) that any changes to the Wan Animate code should be spun off into a separate PR, and this PR should focus on changes to test_save_load_float16.
I have already made the modification here in the separate PR #13364.
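The direction of that fix, a constant kernel following the input's dtype rather than the input following the kernel, can be sketched in isolation (float64 stands in for float16 to keep the example CPU-friendly; the kernel values and names are illustrative, not the actual Wan Animate code):

```python
import torch
import torch.nn.functional as F

kernel = torch.ones(1, 1, 3, 3) / 9.0             # constant kernel, float32 after re-init
x = torch.randn(1, 1, 8, 8, dtype=torch.float64)  # the computation stream's dtype

# F.conv2d requires matching dtypes; casting the kernel to the input's
# dtype preserves the precision of the computation stream.
y = F.conv2d(x, kernel.to(x.dtype), padding=1)
print(y.dtype)  # torch.float64
```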
```python
for key in components:
    if "text_encoder" in key and hasattr(components[key], "eval"):
        components[key].eval()
```
We can safely call eval() when the component exposes the eval attribute.
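Why eval mode matters here: dropout is stochastic in train mode and becomes the identity in eval mode, so switching to eval makes repeated forward passes deterministic. A minimal illustration:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(8)

drop.train()
a = drop(x)  # stochastic: elements are randomly zeroed and the rest rescaled

drop.eval()
b = drop(x)
c = drop(x)
print(torch.equal(b, c))  # True: eval mode is deterministic
print(torch.equal(b, x))  # True: dropout is the identity in eval mode
```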
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Issue:

The test_save_load_float16 test has a hidden non-determinism problem:

- Before saving, the components may still be in train mode (text_encoder contains dropout layers). However, after loading with from_pretrained, the components are in eval mode.
- Some non-persistent buffers (e.g. blur_kernel in MotionConv2d) may have inconsistent dtypes during the save/load process, leading to inconsistent inference results.

Although the test usually "luckily" passes on CUDA (because dropout errors are diluted by numerical propagation), this is fundamentally a correctness issue in both the test design and the implementation.
Fix:

- Set the text_encoder component in the pipeline to eval mode, to be consistent with the handling in other test_save_loadXX tests such as test_save_load_local.
- Cast blur_kernel to the input dtype. This ensures consistency before and after save/load operations and avoids dtype mismatches. (blur_kernel is essentially a fixed mathematical constant, not a learnable parameter. It should adapt to the dtype of the computation stream, rather than forcing the computation stream's data to adapt to it.)

Hi @sayakpaul , please help review, thanks!