diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index ae5a6e6e91eb..5cbe5ff27780 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -291,6 +291,18 @@ def decorator(test_case):
     return decorator
 
 
+def require_torch_version_greater(torch_version):
+    """Decorator marking a test that requires a torch version greater than the given one."""
+
+    def decorator(test_case):
+        correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
+        return unittest.skipUnless(
+            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
+        )(test_case)
+
+    return decorator
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index ac1b0cf3ce6b..2d8b9f698bfe 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -30,6 +30,7 @@
     FluxTransformer2DModel,
     SD3Transformer2DModel,
 )
+from diffusers.quantizers import PipelineQuantizationConfig
 from diffusers.utils import is_accelerate_version, logging
 from diffusers.utils.testing_utils import (
     CaptureLogger,
@@ -44,11 +45,14 @@
     require_peft_backend,
     require_torch,
     require_torch_accelerator,
+    require_torch_version_greater,
     require_transformers_version_greater,
     slow,
     torch_device,
 )
 
+from ..test_torch_compile_utils import QuantCompileTests
+
 
 def get_some_linear_layer(model):
     if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -855,3 +859,26 @@ def test_fp4_double_unsafe(self):
 
     def test_fp4_double_safe(self):
         self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True)
+
+
+@require_torch_version_greater("2.7.1")
+class Bnb4BitCompileTests(QuantCompileTests):
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_4bit",
+        quant_kwargs={
+            "load_in_4bit": True,
+            "bnb_4bit_quant_type": "nf4",
+            "bnb_4bit_compute_dtype": torch.bfloat16,
+        },
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
+    def test_torch_compile(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index bb0702c00bd9..b15a9f72a8f6 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -46,11 +46,14 @@
     require_peft_version_greater,
     require_torch,
     require_torch_accelerator,
+    require_torch_version_greater_equal,
     require_transformers_version_greater,
     slow,
     torch_device,
 )
 
+from ..test_torch_compile_utils import QuantCompileTests
+
 
 def get_some_linear_layer(model):
     if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]:
@@ -821,3 +824,27 @@ def test_serialization_sharded(self):
         out_0 = self.model_0(**inputs)[0]
         out_1 = model_1(**inputs)[0]
         self.assertTrue(torch.equal(out_0, out_1))
+
+
+@require_torch_version_greater_equal("2.6.0")
+class Bnb8BitCompileTests(QuantCompileTests):
+    quantization_config = PipelineQuantizationConfig(
+        quant_backend="bitsandbytes_8bit",
+        quant_kwargs={"load_in_8bit": True},
+        components_to_quantize=["transformer", "text_encoder_2"],
+    )
+
+    def test_torch_compile(self):
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
+        super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+
+    def test_torch_compile_with_cpu_offload(self):
+        super()._test_torch_compile_with_cpu_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
+        )
+
+    @pytest.mark.xfail(reason="Test fails because of an offloading issue in Accelerate caused by conflicting hooks.")
+    def test_torch_compile_with_group_offload(self):
+        super()._test_torch_compile_with_group_offload(
+            quantization_config=self.quantization_config, torch_dtype=torch.float16
+        )
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
new file mode 100644
index 000000000000..1ae77b27d7cd
--- /dev/null
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -0,0 +1,87 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import unittest
+
+import torch
+
+from diffusers import DiffusionPipeline
+from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
+
+
+@require_torch_gpu
+@slow
+class QuantCompileTests(unittest.TestCase):
+    quantization_config = None
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        backend_empty_cache(torch_device)
+        torch.compiler.reset()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        backend_empty_cache(torch_device)
+        torch.compiler.reset()
+
+    def _init_pipeline(self, quantization_config, torch_dtype):
+        pipe = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-3-medium-diffusers",
+            quantization_config=quantization_config,
+            torch_dtype=torch_dtype,
+        )
+        return pipe
+
+    def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
+        # important to ensure fullgraph=True compilation works
+        pipe.transformer.compile(fullgraph=True)
+
+        for _ in range(2):
+            # small resolutions to ensure speedy execution.
+            pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+
+    def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
+        pipe = self._init_pipeline(quantization_config, torch_dtype)
+        pipe.enable_model_cpu_offload()
+        pipe.transformer.compile()
+
+        for _ in range(2):
+            # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) + + def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): + torch._dynamo.config.cache_size_limit = 10000 + + pipe = self._init_pipeline(quantization_config, torch_dtype) + group_offload_kwargs = { + "onload_device": torch.device("cuda"), + "offload_device": torch.device("cpu"), + "offload_type": "leaf_level", + "use_stream": True, + "non_blocking": True, + } + pipe.transformer.enable_group_offload(**group_offload_kwargs) + pipe.transformer.compile() + for name, component in pipe.components.items(): + if name != "transformer" and isinstance(component, torch.nn.Module): + if torch.device(component.device).type == "cpu": + component.to("cuda") + + for _ in range(2): + # small resolutions to ensure speedy execution. + pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)