
Commit d9c6bc6

make style and make quality
1 parent 80d9f8b commit d9c6bc6

4 files changed, +45 -34 lines changed

scripts/convert_wan_to_diffusers.py

Lines changed: 0 additions & 2 deletions
@@ -1,5 +1,4 @@
 import argparse
-import math
 import pathlib
 from typing import Any, Dict, Tuple
 

@@ -582,7 +581,6 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
                 "ffn_dim": 13824,
                 "freq_dim": 256,
                 "in_channels": 36,
-                "motion_encoder_dim": 512,
                 "num_attention_heads": 40,
                 "num_layers": 40,
                 "out_channels": 16,

src/diffusers/models/transformers/transformer_wan_animate.py

Lines changed: 24 additions & 12 deletions
@@ -40,7 +40,15 @@
 
 
 WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
-    "4": 512, "8": 512, "16": 512, "32": 512, "64": 256, "128": 128, "256": 64, "512": 32, "1024": 16
+    "4": 512,
+    "8": 512,
+    "16": 512,
+    "32": 512,
+    "64": 256,
+    "128": 128,
+    "256": 64,
+    "512": 32,
+    "1024": 16,
 }
 
 
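Note: the table above presumably maps string feature-map resolutions to motion-encoder channel widths. A minimal, hypothetical lookup sketch (the helper name is not from the module):

WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES = {
    "4": 512, "8": 512, "16": 512, "32": 512, "64": 256,
    "128": 128, "256": 64, "512": 32, "1024": 16,
}


def channels_for_resolution(resolution: int) -> int:
    # Keys are strings, so convert the integer resolution before the lookup.
    return WAN_ANIMATE_MOTION_ENCODER_CHANNEL_SIZES[str(resolution)]


assert channels_for_resolution(64) == 256
assert channels_for_resolution(1024) == 16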
@@ -77,7 +85,11 @@ def __init__(self, negative_slope: float = 0.2, scale: float = 2**0.5, bias_chan
         self.channels = bias_channels
 
         if self.channels is not None:
-            self.bias = nn.Parameter(torch.zeros(self.channels,))
+            self.bias = nn.Parameter(
+                torch.zeros(
+                    self.channels,
+                )
+            )
         else:
             self.bias = None
 
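The `negative_slope=0.2`, `scale=2**0.5` signature with an optional per-channel bias looks like a StyleGAN-style fused bias + leaky ReLU. A hedged sketch of how such a bias is typically consumed (an assumption about the pattern, not this module's actual forward):

from typing import Optional

import torch
import torch.nn.functional as F


def fused_bias_leaky_relu_sketch(
    x: torch.Tensor,
    bias: Optional[torch.Tensor],
    negative_slope: float = 0.2,
    scale: float = 2**0.5,
) -> torch.Tensor:
    # Assumed pattern: add the learnable per-channel bias (broadcast over the
    # spatial dims), apply leaky ReLU, then rescale to preserve activation variance.
    if bias is not None:
        x = x + bias.view(1, -1, 1, 1)
    return F.leaky_relu(x, negative_slope=negative_slope) * scale


out = fused_bias_leaky_relu_sketch(torch.randn(2, 8, 16, 16), torch.zeros(8))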
@@ -121,13 +133,13 @@ def __init__(
         # Normalize kernel
         kernel = kernel / kernel.sum()
         if blur_upsample_factor > 1:
-            kernel = kernel * (blur_upsample_factor ** 2)
+            kernel = kernel * (blur_upsample_factor**2)
         self.register_buffer("blur_kernel", kernel, persistent=False)
         self.blur = True
 
         # Main Conv2d parameters (with scale factor)
         self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
-        self.scale = 1 / math.sqrt(in_channels * kernel_size ** 2)
+        self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
 
         self.stride = stride
         self.padding = padding
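The `scale = 1 / math.sqrt(in_channels * kernel_size**2)` factor (and `1 / math.sqrt(in_dim)` for the linear layer below) is the equalized learning-rate trick from StyleGAN: parameters stay unit-normal and the fan-in scaling is applied at runtime. A minimal sketch of how such a scale is usually applied in the forward pass (an assumption, not this class's code):

import math

import torch
import torch.nn.functional as F

in_channels, out_channels, kernel_size = 16, 32, 3
weight = torch.randn(out_channels, in_channels, kernel_size, kernel_size)
scale = 1 / math.sqrt(in_channels * kernel_size**2)  # fan-in based rescaling

x = torch.randn(1, in_channels, 8, 8)
# The effective weight is weight * scale, so the stored parameter keeps unit variance.
y = F.conv2d(x, weight * scale, stride=1, padding=1)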
@@ -161,8 +173,8 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
 
     def __repr__(self):
         return (
-            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
-            f' kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
+            f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
+            f" kernel_size={self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
         )
 
 
@@ -179,7 +191,7 @@ def __init__(
 
         # Linear weight with scale factor
         self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
-        self.scale = (1 / math.sqrt(in_dim))
+        self.scale = 1 / math.sqrt(in_dim)
 
         # If an activation is present, the bias will be fused to it
         if bias and not self.use_activation:
@@ -200,8 +212,8 @@ def forward(self, input: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
 
     def __repr__(self):
         return (
-            f'{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},'
-            f' bias={self.bias is not None})'
+            f"{self.__class__.__name__}(in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]},"
+            f" bias={self.bias is not None})"
         )
 
 
@@ -616,7 +628,7 @@ def __init__(
             # TODO: should this always be true?
             assert in_channels == 2 * latent_channels + 4, "in_channels should be 2 * latent_channels + 4"
         else:
-            raise ValueError(f"At least one of `in_channels` and `latent_channels` must be supplied.")
+            raise ValueError("At least one of `in_channels` and `latent_channels` must be supplied.")
         out_channels = out_channels or latent_channels
 
         # 1. Patch & position embedding
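The assertion above connects to the converter change at the top of this commit: with the Wan VAE's 16 latent channels (see the `hidden_states` docstring in the next hunk), `2 * latent_channels + 4` gives the 36 input channels listed in `convert_wan_to_diffusers.py`. A quick check:

latent_channels = 16  # Wan VAE latent channels, per the docstring below
in_channels = 2 * latent_channels + 4
assert in_channels == 36  # matches "in_channels": 36 in the converter config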
@@ -722,8 +734,8 @@ def forward(
         Args:
             hidden_states (`torch.Tensor` of shape `(B, 2C + 4, T + 1, H, W)`):
                 Input noisy video latents of shape `(B, 2C + 4, T + 1, H, W)`, where B is the batch size, C is the
-                number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment,
-                H is the latent height, and W is the latent width.
+                number of latent channels (16 for Wan VAE), T is the number of latent frames in an inference segment, H
+                is the latent height, and W is the latent width.
             timestep: (`torch.LongTensor`):
                 The current timestep in the denoising loop.
             encoder_hidden_states (`torch.Tensor`):

src/diffusers/pipelines/wan/pipeline_wan_animate.py

Lines changed: 18 additions & 17 deletions
@@ -421,9 +421,7 @@ def check_inputs(
                 " undefined when mode is `replace`."
             )
         if mode == "replace" and (not isinstance(background_video, list) or not isinstance(mask_video, list)):
-            raise ValueError(
-                "`background_video` and `mask_video` must be lists of PIL images when mode is `replace`."
-            )
+            raise ValueError("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`.")
 
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -609,7 +607,7 @@ def prepare_prev_segment_cond_latents(
         )
         prev_segment_cond_video = prev_segment_cond_video.unflatten(0, (batch_size, -1)).transpose(1, 2)
 
-        # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if 
+        # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
         # replacing).
         if task == "replace":
             remaining_segment = background_video[:, :, prev_segment_cond_frames:].to(dtype)
@@ -626,7 +624,8 @@ def prepare_prev_segment_cond_latents(
         if isinstance(generator, list):
             if data_batch_size == len(generator):
                 prev_segment_cond_latents = [
-                    retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode) for i, g in enumerate(generator)
+                    retrieve_latents(self.vae.encode(full_segment_cond_video[i].unsqueeze(0)), g, sample_mode)
+                    for i, g in enumerate(generator)
                 ]
             elif data_batch_size == 1:
                 # Like prepare_latents, assume len(generator) == batch_size
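The rewrapped comprehension pairs each batch item with its own `torch.Generator`, keeping per-sample VAE sampling reproducible. A self-contained sketch of the same pattern, with a hypothetical `encode_one` standing in for `retrieve_latents(self.vae.encode(...), ...)`:

import torch


def encode_one(video: torch.Tensor, generator: torch.Generator) -> torch.Tensor:
    # Stand-in for the VAE encode + latent sampling; draws from the per-item generator.
    return torch.randn(4, 8, 8, generator=generator)


batch = torch.randn(2, 3, 16, 64, 64)  # (B, C, T, H, W)
generators = [torch.Generator().manual_seed(i) for i in range(batch.shape[0])]
latents = torch.stack([encode_one(batch[i].unsqueeze(0), g) for i, g in enumerate(generators)])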
@@ -813,11 +812,11 @@ def __call__(
             face_video (`List[PIL.Image.Image]`):
                 The input face video to condition the generation on. Must be a list of PIL images.
             background_video (`List[PIL.Image.Image]`, *optional*):
-                When mode is `"replace"`, the input background video to condition the generation on. Must be a list
-                of PIL images.
-            mask_video (`List[PIL.Image.Image]`, *optional*):
-                When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of
+                When mode is `"replace"`, the input background video to condition the generation on. Must be a list of
                 PIL images.
+            mask_video (`List[PIL.Image.Image]`, *optional*):
+                When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL
+                images.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
@@ -828,16 +827,16 @@ def __call__(
             mode (`str`, defaults to `"animation"`):
                 The mode of the generation. Choose between `"animate"` and `"replace"`.
             prev_segment_conditioning_frames (`int`, defaults to `1`):
-                The number of frames from the previous video segment to be used for temporal guidance. Recommended
-                to be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
+                The number of frames from the previous video segment to be used for temporal guidance. Recommended to
+                be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
             height (`int`, defaults to `720`):
                 The height of the generated video.
             width (`int`, defaults to `1280`):
                 The width of the generated video.
             segment_frame_length (`int`, defaults to `77`):
-                The number of frames in each generated video segment. The total frames of video generated will be
-                equal to the number of frames in `pose_video`; we will generate the video in segments until we have
-                hit this length. In general, should be 4N + 1, where N is a non-negative integer.
+                The number of frames in each generated video segment. The total frames of video generated will be equal
+                to the number of frames in `pose_video`; we will generate the video in segments until we have hit this
+                length. In general, should be 4N + 1, where N is a non-negative integer.
             num_inference_steps (`int`, defaults to `20`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
@@ -846,8 +845,8 @@ def __call__(
                 Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
-                the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in
-                Wan Animate inference.
+                the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan
+                Animate inference.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
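For reference, `guidance_scale` is the weight `w` in the standard classifier-free guidance combination; a one-line sketch with illustrative tensors:

import torch

guidance_scale = 5.0
noise_pred_uncond = torch.randn(1, 16, 8, 8)
noise_pred_text = torch.randn(1, 16, 8, 8)
# Classifier-free guidance: push the prediction toward the text-conditioned branch.
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)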
@@ -923,7 +922,9 @@ def __call__(
                 f"`segment_frame_length - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the"
                 f" nearest number."
             )
-            segment_frame_length = segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+            segment_frame_length = (
+                segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+            )
         segment_frame_length = max(segment_frame_length, 1)
 
         self._guidance_scale = guidance_scale
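The parenthesized expression snaps `segment_frame_length` onto the `4N + 1` grid expected by the docstring, assuming `vae_scale_factor_temporal == 4` for the Wan VAE (substitute the pipeline's actual value). A quick check of the behavior:

vae_scale_factor_temporal = 4  # assumed value for the Wan VAE


def round_segment_length(segment_frame_length: int) -> int:
    # Same arithmetic as the rewrapped expression above, followed by the floor at 1.
    segment_frame_length = (
        segment_frame_length // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
    )
    return max(segment_frame_length, 1)


assert round_segment_length(77) == 77  # already 4N + 1
assert round_segment_length(80) == 81
assert round_segment_length(79) == 77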

tests/models/transformers/test_models_transformer_wan_animate.py

Lines changed: 3 additions & 3 deletions
@@ -55,9 +55,9 @@ def dummy_input(self):
         encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
         clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
         pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        face_pixel_values = torch.randn(
-            (batch_size, 3, inference_segment_length, face_height, face_width)
-        ).to(torch_device)
+        face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
+            torch_device
+        )
 
         return {
             "hidden_states": hidden_states,
