
[Bug] WanPipeline height and width #13034

@Jayce-Ping

Description


Describe the bug

WanPipeline and WanImageToVideoPipeline check that height and width are multiples of 16, but for Wan2.2 TI2V-5B they must actually be multiples of 32: the VAE spatial scale factor is 16 and the transformer's spatial patch size is 2, so the required multiple is vae_scale_factor_spatial * 2 = 32.

Two possible solutions:

(1) Fix the check_inputs method to validate height and width against the correct multiple of 32 (see the sketch after the snippet below).

(2) Add the following code to the inference method to adjust height and width automatically:

        multiple_of = self.vae_scale_factor_spatial * 2
        calc_height = height // multiple_of * multiple_of
        calc_width = width // multiple_of * multiple_of
        if height != calc_height or width != calc_width:
            logger.warning(
                f"`height` and `width` must be multiples of {multiple_of} for proper patchification. "
                f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
            )
            height, width = calc_height, calc_width

Reproduction

import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

dtype = torch.bfloat16
device = "cuda"

model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=dtype)
pipe.to(device)

height = 240  # divisible by 16 (16 * 15 == 240) but not by 32
width = 240  # divisible by 16 but not by 32
num_frames = 121
num_inference_steps = 50
guidance_scale = 5.0


prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
# Standard Wan Chinese negative prompt; translation: "garish colors, overexposed, static, blurry details,
# subtitles, style, artwork, painting, still frame, overall gray, worst quality, low quality, JPEG artifacts,
# ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs,
# fused fingers, motionless frame, cluttered background, three legs, crowded background, walking backwards"
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

output = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
).frames[0]
export_to_video(output, "5bit2v_output.mp4", fps=24)

Logs

Traceback (most recent call last):
  File "/home/users/astar/cfar/stuchengyou/Flow-Factory/test.py", line 24, in <module>
    output = pipe(
             ^^^^^
  File "/home/users/astar/cfar/stuchengyou/jcy/.conda/envs/ff-flash/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/Flow-Factory/diffusers/src/diffusers/pipelines/wan/pipeline_wan.py", line 594, in __call__
    noise_pred = current_model(
                 ^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/jcy/.conda/envs/ff-flash/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/jcy/.conda/envs/ff-flash/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/Flow-Factory/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 688, in forward
    hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/jcy/.conda/envs/ff-flash/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/jcy/.conda/envs/ff-flash/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/users/astar/cfar/stuchengyou/Flow-Factory/diffusers/src/diffusers/models/transformers/transformer_wan.py", line 485, in forward
    norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~
RuntimeError: The size of tensor a (98) must match the size of tensor b (128) at non-singleton dimension 1

System Info

diffusers-cli env


  • 🤗 Diffusers version: 0.37.0.dev0
  • Platform: Linux-5.15.0-113-generic-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.12.12
  • PyTorch version (GPU?): 2.8.0+cu126 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.6
  • Accelerate version: 1.12.0
  • PEFT version: 0.18.1
  • Bitsandbytes version: not installed
  • Safetensors version: 0.7.0
  • xFormers version: not installed
  • Accelerator: NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
    NVIDIA H200, 143771 MiB
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: <N/A>

Who can help?

@DN6 @a-r-r-o-w
