Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 137 additions & 40 deletions tests/models/transformers/test_models_transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,57 +12,52 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import WanTransformer3DModel

from ...testing_utils import (
enable_full_determinism,
torch_device,
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True

class WanTransformer3DTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12
def model_class(self):
return WanTransformer3DModel

hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 2, 16, 16)

@property
def input_shape(self):
return (4, 1, 16, 16)
def main_input_name(self) -> str:
return "hidden_states"

@property
def output_shape(self):
return (4, 1, 16, 16)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
return {
"patch_size": (1, 2, 2),
"num_attention_heads": 2,
"attention_head_dim": 12,
Expand All @@ -76,16 +71,118 @@ def prepare_init_args_and_inputs_for_common(self):
"qk_norm": "rms_norm_across_heads",
"rope_max_seq_len": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
num_channels = 4
num_frames = 2
height = 16
width = 16
text_encoder_embedding_dim = 16
sequence_length = 12

return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}


class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Wan Transformer 3D.

    All behavior comes from the two bases: ``WanTransformer3DTesterConfig``
    supplies the model class, init kwargs, and dummy inputs, while
    ``ModelTesterMixin`` contributes the actual test methods.
    """


class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Wan Transformer 3D.

    Empty body by design: the config base provides the model/inputs and
    ``MemoryTesterMixin`` provides the test methods.
    """


class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Wan Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        # Override the mixin's generic check with the module class name that
        # is expected to have gradient checkpointing enabled for this model.
        expected_set = {"WanTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = WanTransformer3DModel
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Wan Transformer 3D.

    Empty body by design: the config base provides the model/inputs and
    ``AttentionTesterMixin`` provides the test methods.
    """


class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Wan Transformer 3D.

    Empty body by design: the config base provides the model/inputs and
    ``TorchCompileTesterMixin`` provides the test methods.
    """


class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for Wan Transformer 3D.

    Empty body by design: the config base provides the model/inputs and
    ``BitsAndBytesTesterMixin`` provides the test methods.
    """


class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for Wan Transformer 3D.

    Empty body by design: the config base provides the model/inputs and
    ``TorchAoTesterMixin`` provides the test methods.
    """


class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
    """GGUF quantization tests for Wan Transformer 3D."""

    @property
    def gguf_filename(self):
        # Checkpoint used for GGUF loading tests.
        # NOTE(review): this is a "/blob/" (HTML page) URL rather than a
        # "/resolve/" (raw file) URL — confirm the GGUF loader normalizes it.
        return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"

    @property
    def torch_dtype(self):
        # Dtype used for the non-quantized inputs/compute in the GGUF tests.
        return torch.bfloat16

    def get_dummy_inputs(self):
        """Override to provide inputs matching the real Wan I2V model dimensions.

        Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
        """
        # Shapes below mirror the real checkpoint's expected inputs
        # (batch=1; latent video of 36 channels x 2 frames x 64x64; text
        # sequence of 512 tokens; 257 image-embedding tokens) — assumed from
        # the docstring above, not verifiable from this file alone.
        return {
            "hidden_states": randn_tensor(
                (1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states": randn_tensor(
                (1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "encoder_hidden_states_image": randn_tensor(
                (1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
            ),
            "timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
        }


class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Wan Transformer 3D."""

def prepare_init_args_and_inputs_for_common(self):
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
@property
def gguf_filename(self):
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"

@property
def torch_dtype(self):
return torch.bfloat16

def get_dummy_inputs(self):
"""Override to provide inputs matching the real Wan I2V model dimensions.

Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
"""
return {
"hidden_states": randn_tensor(
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states": randn_tensor(
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"encoder_hidden_states_image": randn_tensor(
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
}
Loading
Loading