@@ -421,9 +421,7 @@ def check_inputs(
421421 " undefined when mode is `replace`."
422422 )
423423 if mode == "replace" and (not isinstance (background_video , list ) or not isinstance (mask_video , list )):
424- raise ValueError (
425- "`background_video` and `mask_video` must be lists of PIL images when mode is `replace`."
426- )
424+ raise ValueError ("`background_video` and `mask_video` must be lists of PIL images when mode is `replace`." )
427425
428426 if height % 16 != 0 or width % 16 != 0 :
429427 raise ValueError (f"`height` and `width` have to be divisible by 16 but are { height } and { width } ." )
@@ -609,7 +607,7 @@ def prepare_prev_segment_cond_latents(
609607 )
610608 prev_segment_cond_video = prev_segment_cond_video .unflatten (0 , (batch_size , - 1 )).transpose (1 , 2 )
611609
612- # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
610+ # Fill the remaining part of the cond video segment with zeros (if animating) or the background video (if
613611 # replacing).
614612 if task == "replace" :
615613 remaining_segment = background_video [:, :, prev_segment_cond_frames :].to (dtype )
@@ -626,7 +624,8 @@ def prepare_prev_segment_cond_latents(
626624 if isinstance (generator , list ):
627625 if data_batch_size == len (generator ):
628626 prev_segment_cond_latents = [
629- retrieve_latents (self .vae .encode (full_segment_cond_video [i ].unsqueeze (0 )), g , sample_mode ) for i , g in enumerate (generator )
627+ retrieve_latents (self .vae .encode (full_segment_cond_video [i ].unsqueeze (0 )), g , sample_mode )
628+ for i , g in enumerate (generator )
630629 ]
631630 elif data_batch_size == 1 :
632631 # Like prepare_latents, assume len(generator) == batch_size
@@ -813,11 +812,11 @@ def __call__(
813812 face_video (`List[PIL.Image.Image]`):
814813 The input face video to condition the generation on. Must be a list of PIL images.
815814 background_video (`List[PIL.Image.Image]`, *optional*):
816- When mode is `"replace"`, the input background video to condition the generation on. Must be a list
817- of PIL images.
818- mask_video (`List[PIL.Image.Image]`, *optional*):
819- When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of
815+ When mode is `"replace"`, the input background video to condition the generation on. Must be a list of
820816 PIL images.
817+ mask_video (`List[PIL.Image.Image]`, *optional*):
818+ When mode is `"replace"`, the input mask video to condition the generation on. Must be a list of PIL
819+ images.
821820 prompt (`str` or `List[str]`, *optional*):
822821 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
823822 instead.
@@ -828,16 +827,16 @@ def __call__(
828827 mode (`str`, defaults to `"animate"`):
829828 The mode of the generation. Choose between `"animate"` and `"replace"`.
830829 prev_segment_conditioning_frames (`int`, defaults to `1`):
831- The number of frames from the previous video segment to be used for temporal guidance. Recommended
832- to be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
830+ The number of frames from the previous video segment to be used for temporal guidance. Recommended to
831+ be 1 or 5. In general, should be 4N + 1, where N is a non-negative integer.
833832 height (`int`, defaults to `720`):
834833 The height of the generated video.
835834 width (`int`, defaults to `1280`):
836835 The width of the generated video.
837836 segment_frame_length (`int`, defaults to `77`):
838- The number of frames in each generated video segment. The total frames of video generated will be
839- equal to the number of frames in `pose_video`; we will generate the video in segments until we have
840- hit this length. In general, should be 4N + 1, where N is a non-negative integer.
837+ The number of frames in each generated video segment. The total frames of video generated will be equal
838+ to the number of frames in `pose_video`; we will generate the video in segments until we have hit this
839+ length. In general, should be 4N + 1, where N is a non-negative integer.
841840 num_inference_steps (`int`, defaults to `20`):
842841 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
843842 expense of slower inference.
@@ -846,8 +845,8 @@ def __call__(
846845 Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2
847846 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
848847 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
849- the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in
850- Wan Animate inference.
848+ the text `prompt`, usually at the expense of lower image quality. By default, CFG is not used in Wan
849+ Animate inference.
851850 num_videos_per_prompt (`int`, *optional*, defaults to 1):
852851 The number of images to generate per prompt.
853852 generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -923,7 +922,9 @@ def __call__(
923922 f"`segment_frame_length - 1` has to be divisible by { self .vae_scale_factor_temporal } . Rounding to the"
924923 f" nearest number."
925924 )
926- segment_frame_length = segment_frame_length // self .vae_scale_factor_temporal * self .vae_scale_factor_temporal + 1
925+ segment_frame_length = (
926+ segment_frame_length // self .vae_scale_factor_temporal * self .vae_scale_factor_temporal + 1
927+ )
927928 segment_frame_length = max (segment_frame_length , 1 )
928929
929930 self ._guidance_scale = guidance_scale
0 commit comments