
[tests] tests for compilation + quantization (bnb) #11672


Merged: 12 commits, merged on Jun 11, 2025
12 changes: 12 additions & 0 deletions src/diffusers/utils/testing_utils.py
@@ -291,6 +291,18 @@ def decorator(test_case):
return decorator


def require_torch_version_greater(torch_version):
    """Decorator marking a test that requires a torch version greater than the given version."""

    def decorator(test_case):
        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
        return unittest.skipUnless(
            correct_torch_version, f"test requires a torch version greater than {torch_version}"
        )(test_case)

    return decorator


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
        test_case
    )
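For reference, a minimal usage sketch of the new decorator (the test class and test name here are hypothetical; the import path matches the file above):

import unittest

from diffusers.utils.testing_utils import require_torch_version_greater


class ExampleTests(unittest.TestCase):
    @require_torch_version_greater("2.7.1")
    def test_only_on_recent_torch(self):
        # Runs only when the installed torch is strictly newer than 2.7.1;
        # otherwise unittest reports the test as skipped.
        self.assertTrue(True)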
43 changes: 43 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -30,6 +30,7 @@
    FluxTransformer2DModel,
    SD3Transformer2DModel,
)
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
    CaptureLogger,
@@ -44,6 +45,8 @@
    require_peft_backend,
    require_torch,
    require_torch_accelerator,
    require_torch_gpu,
    require_torch_version_greater,
    require_transformers_version_greater,
    slow,
    torch_device,
@@ -855,3 +858,43 @@ def test_fp4_double_unsafe(self):

    def test_fp4_double_safe(self):
        self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)


@require_torch_gpu
@slow
class Bnb4BitCompileTests(unittest.TestCase):
    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)
        torch.compiler.reset()

    @require_torch_version_greater("2.7.1")
    def test_torch_compile_4bit(self):
        # bnb 4-bit dequantization emits ops with dynamic output shapes, which
        # dynamo can only trace when this flag is enabled.
        torch._dynamo.config.capture_dynamic_output_shape_ops = True

        quantization_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=["transformer"],
        )
        pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-3-medium-diffusers",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        # fullgraph=True makes compilation fail loudly on any graph break.
        pipe.transformer.compile(fullgraph=True)

        # Two passes: the first triggers compilation, the second exercises the compiled graph.
        for _ in range(2):
            # with torch._dynamo.config.patch(error_on_recompile=True):
            pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256)