
Commit 0b76728

DN6 and sayakpaul authored
Refactor Model Tests (#12822)
* update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update * update --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent 973e334 commit 0b76728

File tree: 16 files changed, +5972 -103 lines changed

tests/conftest.py

Lines changed: 16 additions & 0 deletions
@@ -32,6 +32,22 @@

 def pytest_configure(config):
     config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+    config.addinivalue_line("markers", "lora: marks tests for LoRA/PEFT functionality")
+    config.addinivalue_line("markers", "ip_adapter: marks tests for IP Adapter functionality")
+    config.addinivalue_line("markers", "training: marks tests for training functionality")
+    config.addinivalue_line("markers", "attention: marks tests for attention processor functionality")
+    config.addinivalue_line("markers", "memory: marks tests for memory optimization functionality")
+    config.addinivalue_line("markers", "cpu_offload: marks tests for CPU offloading functionality")
+    config.addinivalue_line("markers", "group_offload: marks tests for group offloading functionality")
+    config.addinivalue_line("markers", "compile: marks tests for torch.compile functionality")
+    config.addinivalue_line("markers", "single_file: marks tests for single file checkpoint loading")
+    config.addinivalue_line("markers", "quantization: marks tests for quantization functionality")
+    config.addinivalue_line("markers", "bitsandbytes: marks tests for BitsAndBytes quantization functionality")
+    config.addinivalue_line("markers", "quanto: marks tests for Quanto quantization functionality")
+    config.addinivalue_line("markers", "torchao: marks tests for TorchAO quantization functionality")
+    config.addinivalue_line("markers", "gguf: marks tests for GGUF quantization functionality")
+    config.addinivalue_line("markers", "modelopt: marks tests for NVIDIA ModelOpt quantization functionality")
+    config.addinivalue_line("markers", "context_parallel: marks tests for context parallel inference functionality")
     config.addinivalue_line("markers", "slow: mark test as slow")
     config.addinivalue_line("markers", "nightly: mark test as nightly")
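These markers let developers and CI select test groups from the command line, e.g. `pytest tests/models -m "quantization and not gguf"`. In the refactored mixins they are applied through small decorator helpers such as `is_attention`, imported from `testing_utils` in the attention mixin further below. Those helpers are not part of this diff; the following is a minimal sketch, assuming each helper is simply an alias for the corresponding registered pytest mark:

# Hypothetical sketch: the real helpers live in the test suite's testing_utils and may differ.
import pytest

# Assumed implementation: the helper is just the registered mark, so applying it to a
# test class is equivalent to decorating it with @pytest.mark.attention.
is_attention = pytest.mark.attention


@is_attention
class TestSomeModelAttention:
    def test_dummy(self):
        # Selected by `pytest -m attention`, skipped by `pytest -m "not attention"`.
        assert True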

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
from .attention import AttentionTesterMixin
from .cache import (
    CacheTesterMixin,
    FasterCacheConfigMixin,
    FasterCacheTesterMixin,
    FirstBlockCacheConfigMixin,
    FirstBlockCacheTesterMixin,
    PyramidAttentionBroadcastConfigMixin,
    PyramidAttentionBroadcastTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
from .ip_adapter import IPAdapterTesterMixin
from .lora import LoraHotSwappingForModelTesterMixin, LoraTesterMixin
from .memory import CPUOffloadTesterMixin, GroupOffloadTesterMixin, LayerwiseCastingTesterMixin, MemoryTesterMixin
from .parallelism import ContextParallelTesterMixin
from .quantization import (
    BitsAndBytesCompileTesterMixin,
    BitsAndBytesConfigMixin,
    BitsAndBytesTesterMixin,
    GGUFCompileTesterMixin,
    GGUFConfigMixin,
    GGUFTesterMixin,
    ModelOptCompileTesterMixin,
    ModelOptConfigMixin,
    ModelOptTesterMixin,
    QuantizationCompileTesterMixin,
    QuantizationTesterMixin,
    QuantoCompileTesterMixin,
    QuantoConfigMixin,
    QuantoTesterMixin,
    TorchAoCompileTesterMixin,
    TorchAoConfigMixin,
    TorchAoTesterMixin,
)
from .single_file import SingleFileTesterMixin
from .training import TrainingTesterMixin


__all__ = [
    "AttentionTesterMixin",
    "BaseModelTesterConfig",
    "BitsAndBytesCompileTesterMixin",
    "BitsAndBytesConfigMixin",
    "BitsAndBytesTesterMixin",
    "CacheTesterMixin",
    "ContextParallelTesterMixin",
    "CPUOffloadTesterMixin",
    "FasterCacheConfigMixin",
    "FasterCacheTesterMixin",
    "FirstBlockCacheConfigMixin",
    "FirstBlockCacheTesterMixin",
    "GGUFCompileTesterMixin",
    "GGUFConfigMixin",
    "GGUFTesterMixin",
    "GroupOffloadTesterMixin",
    "IPAdapterTesterMixin",
    "LayerwiseCastingTesterMixin",
    "LoraHotSwappingForModelTesterMixin",
    "LoraTesterMixin",
    "MemoryTesterMixin",
    "ModelOptCompileTesterMixin",
    "ModelOptConfigMixin",
    "ModelOptTesterMixin",
    "ModelTesterMixin",
    "PyramidAttentionBroadcastConfigMixin",
    "PyramidAttentionBroadcastTesterMixin",
    "QuantizationCompileTesterMixin",
    "QuantizationTesterMixin",
    "QuantoCompileTesterMixin",
    "QuantoConfigMixin",
    "QuantoTesterMixin",
    "SingleFileTesterMixin",
    "TorchAoCompileTesterMixin",
    "TorchAoConfigMixin",
    "TorchAoTesterMixin",
    "TorchCompileTesterMixin",
    "TrainingTesterMixin",
]
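This new `__init__.py` re-exports every capability mixin from one place, so a per-model test module can compose exactly the behaviors it wants alongside a small config class that supplies `model_class`, `get_init_dict()`, and `get_dummy_inputs()` (the contract documented in the attention mixin below). The following is an illustrative sketch only; the class names, the choice of UNet2DModel, and the tiny config values are assumptions, not code from this commit, and the import path of the mixins package is not shown in this page so it is left as a commented placeholder:

# Hypothetical composition example; real per-model test modules in this commit may differ.
import torch

from diffusers import UNet2DModel

# from <mixins package introduced by this commit> import AttentionTesterMixin, TrainingTesterMixin


class UNet2DTesterConfig:
    """Supplies what the mixins expect: model_class, get_init_dict(), get_dummy_inputs()."""

    model_class = UNet2DModel

    def get_init_dict(self):
        # Deliberately tiny model so the shared tests run quickly (illustrative values).
        return {
            "sample_size": 32,
            "in_channels": 3,
            "out_channels": 3,
            "layers_per_block": 1,
            "block_out_channels": (32, 64),
            "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
            "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
        }

    def get_dummy_inputs(self):
        return {
            "sample": torch.randn(1, 3, 32, 32),
            "timestep": torch.tensor([10]),
        }


class TestUNet2DModel(UNet2DTesterConfig, AttentionTesterMixin, TrainingTesterMixin):
    # Inherits test_get_set_processor, test_fuse_unfuse_qkv_projections, training tests, etc.
    pass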
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc

import pytest
import torch

from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import (
    AttnProcessor,
)

from ...testing_utils import (
    assert_tensors_close,
    backend_empty_cache,
    is_attention,
    torch_device,
)


@is_attention
class AttentionTesterMixin:
    """
    Mixin class for testing attention processor and module functionality on models.

    Tests functionality from AttentionModuleMixin including:
    - Attention processor management (set/get)
    - QKV projection fusion/unfusion
    - Attention backends (XFormers, NPU, etc.)

    Expected from config mixin:
    - model_class: The model class to test

    Expected methods from config mixin:
    - get_init_dict(): Returns dict of arguments to initialize the model
    - get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass

    Pytest mark: attention
    Use `pytest -m "not attention"` to skip these tests
    """

    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    @torch.no_grad()
    def test_fuse_unfuse_qkv_projections(self, atol=1e-3, rtol=0):
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        if not hasattr(model, "fuse_qkv_projections"):
            pytest.skip("Model does not support QKV projection fusion.")

        output_before_fusion = model(**inputs_dict, return_dict=False)[0]

        model.fuse_qkv_projections()

        has_fused_projections = False
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                if hasattr(module, "to_qkv") or hasattr(module, "to_kv"):
                    has_fused_projections = True
                    assert module.fused_projections, "fused_projections flag should be True"
                    break

        if has_fused_projections:
            output_after_fusion = model(**inputs_dict, return_dict=False)[0]

            assert_tensors_close(
                output_before_fusion,
                output_after_fusion,
                atol=atol,
                rtol=rtol,
                msg="Output should not change after fusing projections",
            )

        model.unfuse_qkv_projections()

        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                assert not hasattr(module, "to_qkv"), "to_qkv should be removed after unfusing"
                assert not hasattr(module, "to_kv"), "to_kv should be removed after unfusing"
                assert not module.fused_projections, "fused_projections flag should be False"

        output_after_unfusion = model(**inputs_dict, return_dict=False)[0]

        assert_tensors_close(
            output_before_fusion,
            output_after_unfusion,
            atol=atol,
            rtol=rtol,
            msg="Output should match original after unfusing projections",
        )

    def test_get_set_processor(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        # Check if model has attention processors
        if not hasattr(model, "attn_processors"):
            pytest.skip("Model does not have attention processors.")

        # Test getting processors
        processors = model.attn_processors
        assert isinstance(processors, dict), "attn_processors should return a dict"
        assert len(processors) > 0, "Model should have at least one attention processor"

        # Test that all processors can be retrieved via get_processor
        for module in model.modules():
            if isinstance(module, AttentionModuleMixin):
                processor = module.get_processor()
                assert processor is not None, "get_processor should return a processor"

                # Test setting a new processor
                new_processor = AttnProcessor()
                module.set_processor(new_processor)
                retrieved_processor = module.get_processor()
                assert retrieved_processor is new_processor, "Retrieved processor should be the same as the one set"

    def test_attention_processor_dict(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict of new processors
        new_processors = {key: AttnProcessor() for key in current_processors.keys()}

        # Set processors using dict
        model.set_attn_processor(new_processors)

        # Verify all processors were set
        updated_processors = model.attn_processors
        for key in current_processors.keys():
            assert type(updated_processors[key]) == AttnProcessor, f"Processor {key} should be AttnProcessor"

    def test_attention_processor_count_mismatch_raises_error(self):
        init_dict = self.get_init_dict()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        if not hasattr(model, "set_attn_processor"):
            pytest.skip("Model does not support setting attention processors.")

        # Get current processors
        current_processors = model.attn_processors

        # Create a dict with wrong number of processors
        wrong_processors = {list(current_processors.keys())[0]: AttnProcessor()}

        # Verify error is raised
        with pytest.raises(ValueError) as exc_info:
            model.set_attn_processor(wrong_processors)

        assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
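The helpers `assert_tensors_close`, `backend_empty_cache`, `is_attention`, and `torch_device` are imported from the test suite's `testing_utils` module, which is not included in the lines shown above. A minimal sketch of how a comparison helper like `assert_tensors_close` could be implemented, assuming it wraps `torch.testing.assert_close` (the real helper may differ):

# Hypothetical sketch of the comparison helper used above; not the actual testing_utils code.
import torch


def assert_tensors_close(actual, expected, atol=1e-4, rtol=0.0, msg=""):
    """Fail with the caller's message when two tensors are not numerically close."""
    try:
        # torch.testing.assert_close requires atol and rtol to be passed together.
        torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
    except AssertionError as e:
        raise AssertionError(f"{msg}\n{e}") from e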
