From 6fe24149efe8378f335fb0f61e68c4d1d95be40a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 12:55:26 +0530 Subject: [PATCH 1/9] start adding compilation tests for quantization. --- tests/quantization/bnb/test_4bit.py | 34 +++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index ac1b0cf3ce6b..8d68eaf00c78 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -28,6 +28,7 @@ DiffusionPipeline, FluxControlPipeline, FluxTransformer2DModel, + PipelineQuantizationConfig, SD3Transformer2DModel, ) from diffusers.utils import is_accelerate_version, logging @@ -44,6 +45,8 @@ require_peft_backend, require_torch, require_torch_accelerator, + require_torch_gpu, + require_torch_version_greater_equal, require_transformers_version_greater, slow, torch_device, @@ -855,3 +858,34 @@ def test_fp4_double_unsafe(self): def test_fp4_double_safe(self): self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) + + +@require_torch_gpu +@slow +class Bnb4BitCompileTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + @require_torch_version_greater_equal("2.8") + def test_torch_compile_4bit(self): + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={"load_in_4bit": True}, + components_to_quantize=["transformer"], + ) + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-flux-pipe", quantization_config=quantization_config, torch_dtype=torch.bfloat16 + ).to("cuda") + pipe.transformer.compile(fullgraph=True) + + for _ in range(2): + pipe("a dog", num_inference_steps=4, max_sequence_length=16) From 29cca994ea284963e2d33252fb55bf806a966b27 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 6 Jun 2025 13:40:27 +0530 Subject: [PATCH 2/9] fixes --- src/diffusers/utils/testing_utils.py | 12 ++++++++++++ tests/quantization/bnb/test_4bit.py | 21 +++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index e19a9f83fdb9..a1c32476210e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -291,6 +291,18 @@ def decorator(test_case): return decorator +def require_torch_version_greater(torch_version): + """Decorator marking a test that requires torch with a specific version greater.""" + + def decorator(test_case): + correct_torch_version = is_torch_available() and is_torch_version(">", torch_version) + return unittest.skipUnless( + correct_torch_version, f"test requires torch with the version greater than {torch_version}" + )(test_case) + + return decorator + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 8d68eaf00c78..0567e63bf7bb 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -28,9 +28,9 @@ DiffusionPipeline, FluxControlPipeline, FluxTransformer2DModel, - PipelineQuantizationConfig, SD3Transformer2DModel, ) +from diffusers.quantizers 
import PipelineQuantizationConfig from diffusers.utils import is_accelerate_version, logging from diffusers.utils.testing_utils import ( CaptureLogger, @@ -46,7 +46,7 @@ require_torch, require_torch_accelerator, require_torch_gpu, - require_torch_version_greater_equal, + require_torch_version_greater, require_transformers_version_greater, slow, torch_device, @@ -875,17 +875,26 @@ def tearDown(self): backend_empty_cache(torch_device) torch.compiler.reset() - @require_torch_version_greater_equal("2.8") + @require_torch_version_greater("2.7.1") def test_torch_compile_4bit(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + quantization_config = PipelineQuantizationConfig( quant_backend="bitsandbytes_4bit", - quant_kwargs={"load_in_4bit": True}, + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, components_to_quantize=["transformer"], ) pipe = DiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-flux-pipe", quantization_config=quantization_config, torch_dtype=torch.bfloat16 + "stabilityai/stable-diffusion-3-medium-diffusers", + quantization_config=quantization_config, + torch_dtype=torch.bfloat16, ).to("cuda") pipe.transformer.compile(fullgraph=True) for _ in range(2): - pipe("a dog", num_inference_steps=4, max_sequence_length=16) + # with torch._dynamo.config.patch(error_on_recompile=True): + pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256) From edf66b7953b0149a51f9b8fac37df5f4bbf084f2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 09:01:24 +0530 Subject: [PATCH 3/9] make common utility. --- tests/quantization/bnb/test_4bit.py | 34 +++---------- tests/quantization/bnb/test_mixed_int8.py | 19 +++++++ .../quantization/test_torch_compile_utils.py | 49 +++++++++++++++++++ 3 files changed, 74 insertions(+), 28 deletions(-) create mode 100644 tests/quantization/test_torch_compile_utils.py diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 0567e63bf7bb..7bff375bb177 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -45,13 +45,14 @@ require_peft_backend, require_torch, require_torch_accelerator, - require_torch_gpu, require_torch_version_greater, require_transformers_version_greater, slow, torch_device, ) +from ..utils import QuantCompileMiscTests + def get_some_linear_layer(model): if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: @@ -860,23 +861,9 @@ def test_fp4_double_safe(self): self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) -@require_torch_gpu -@slow -class Bnb4BitCompileTests(unittest.TestCase): - def setUp(self): - super().setUp() - gc.collect() - backend_empty_cache(torch_device) - torch.compiler.reset() - - def tearDown(self): - super().tearDown() - gc.collect() - backend_empty_cache(torch_device) - torch.compiler.reset() - +class Bnb4BitCompileTests(QuantCompileMiscTests): @require_torch_version_greater("2.7.1") - def test_torch_compile_4bit(self): + def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True quantization_config = PipelineQuantizationConfig( @@ -886,15 +873,6 @@ def test_torch_compile_4bit(self): "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16, }, - components_to_quantize=["transformer"], + components_to_quantize=["transformer", "text_encoder_2"], ) - pipe = DiffusionPipeline.from_pretrained( - 
"stabilityai/stable-diffusion-3-medium-diffusers", - quantization_config=quantization_config, - torch_dtype=torch.bfloat16, - ).to("cuda") - pipe.transformer.compile(fullgraph=True) - - for _ in range(2): - # with torch._dynamo.config.patch(error_on_recompile=True): - pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256) + super().test_torch_compile(quantization_config=quantization_config) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 98575b86cdcc..a6eb8e976ac6 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -28,6 +28,7 @@ SD3Transformer2DModel, logging, ) +from diffusers.quantizers import PipelineQuantizationConfig from diffusers.utils import is_accelerate_version from diffusers.utils.testing_utils import ( CaptureLogger, @@ -42,11 +43,14 @@ require_peft_version_greater, require_torch, require_torch_accelerator, + require_torch_version_greater_equal, require_transformers_version_greater, slow, torch_device, ) +from ..utils import QuantCompileMiscTests + def get_some_linear_layer(model): if model.__class__.__name__ in ["SD3Transformer2DModel", "FluxTransformer2DModel"]: @@ -773,3 +777,18 @@ def test_serialization_sharded(self): out_0 = self.model_0(**inputs)[0] out_1 = model_1(**inputs)[0] self.assertTrue(torch.equal(out_0, out_1)) + + +class Bnb8BitCompileTests(QuantCompileMiscTests): + @require_torch_version_greater_equal("2.6.0") + def test_torch_compile(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={ + "load_in_8bit": True, + }, + components_to_quantize=["transformer", "text_encoder_2"], + ) + super().test_torch_compile(quantization_config=quantization_config, torch_dtype=torch.float16) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py new file mode 100644 index 000000000000..2fc519adaf7e --- /dev/null +++ b/tests/quantization/test_torch_compile_utils.py @@ -0,0 +1,49 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import gc +import unittest + +import torch + +from diffusers import DiffusionPipeline +from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device + + +@require_torch_gpu +@slow +class QuantCompileMiscTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def tearDown(self): + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + torch.compiler.reset() + + def test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = DiffusionPipeline.from_pretrained( + "stabilityai/stable-diffusion-3-medium-diffusers", + quantization_config=quantization_config, + torch_dtype=torch_dtype, + ).to("cuda") + pipe.transformer.compile(fullgraph=True) + + for _ in range(2): + # small resolutions to ensure speedy execution. + pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256) From 11cfd6c0817efb188934e511ad34d6bc163bb0e1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 10:07:06 +0530 Subject: [PATCH 4/9] modularize. --- tests/quantization/bnb/test_4bit.py | 28 +++++++++++-------- tests/quantization/bnb/test_mixed_int8.py | 22 +++++++++------ .../quantization/test_torch_compile_utils.py | 22 +++++++++++++-- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 7bff375bb177..e161f8e4a132 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -51,7 +51,7 @@ torch_device, ) -from ..utils import QuantCompileMiscTests +from ..test_torch_compile_utils import QuantCompileMiscTests def get_some_linear_layer(model): @@ -861,18 +861,24 @@ def test_fp4_double_safe(self): self.test_serialization(quant_type="fp4", double_quant=True, safe_serialization=True) +@require_torch_version_greater("2.7.1") class Bnb4BitCompileTests(QuantCompileMiscTests): - @require_torch_version_greater("2.7.1") + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, + components_to_quantize=["transformer", "text_encoder_2"], + ) + def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile(quantization_config=self.quantization_config) - quantization_config = PipelineQuantizationConfig( - quant_backend="bitsandbytes_4bit", - quant_kwargs={ - "load_in_4bit": True, - "bnb_4bit_quant_type": "nf4", - "bnb_4bit_compute_dtype": torch.bfloat16, - }, - components_to_quantize=["transformer", "text_encoder_2"], + def test_torch_compile_with_cpu_offload(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile_with_cpu_offload( + quantization_config=self.quantization_config, torch_dtype=torch.float16 ) - super().test_torch_compile(quantization_config=quantization_config) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index a6eb8e976ac6..a8f1d98610fe 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -49,7 +49,7 @@ torch_device, ) -from ..utils import QuantCompileMiscTests +from ..test_torch_compile_utils import QuantCompileMiscTests def get_some_linear_layer(model): @@ -779,16 +779,20 @@ def test_serialization_sharded(self): self.assertTrue(torch.equal(out_0, 
out_1)) +@require_torch_version_greater_equal("2.6.0") class Bnb8BitCompileTests(QuantCompileMiscTests): - @require_torch_version_greater_equal("2.6.0") + quantization_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_8bit", + quant_kwargs={"load_in_8bit": True}, + components_to_quantize=["transformer", "text_encoder_2"], + ) + def test_torch_compile(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16) - quantization_config = PipelineQuantizationConfig( - quant_backend="bitsandbytes_8bit", - quant_kwargs={ - "load_in_8bit": True, - }, - components_to_quantize=["transformer", "text_encoder_2"], + def test_torch_compile_with_cpu_offload(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile_with_cpu_offload( + quantization_config=self.quantization_config, torch_dtype=torch.float16 ) - super().test_torch_compile(quantization_config=quantization_config, torch_dtype=torch.float16) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index 2fc519adaf7e..ab764b14bed7 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -24,6 +24,8 @@ @require_torch_gpu @slow class QuantCompileMiscTests(unittest.TestCase): + quantization_config = None + def setUp(self): super().setUp() gc.collect() @@ -36,14 +38,28 @@ def tearDown(self): backend_empty_cache(torch_device) torch.compiler.reset() - def test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16): + def _init_pipeline(self, quantization_config, torch_dtype): pipe = DiffusionPipeline.from_pretrained( "stabilityai/stable-diffusion-3-medium-diffusers", quantization_config=quantization_config, torch_dtype=torch_dtype, - ).to("cuda") + ) + return pipe + + def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda") + # import to ensure fullgraph True pipe.transformer.compile(fullgraph=True) for _ in range(2): # small resolutions to ensure speedy execution. - pipe("a dog", num_inference_steps=4, max_sequence_length=16, height=256, width=256) + pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) + + def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = self._init_pipeline(quantization_config, torch_dtype) + pipe.enable_model_cpu_offload() + pipe.transformer.compile() + + for _ in range(2): + # small resolutions to ensure speedy execution. 
+ pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) From 0e4f1523aa2c0f9ade5cdd84b722ee9b5afc1f90 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 10:14:49 +0530 Subject: [PATCH 5/9] add group offloading+compile --- tests/quantization/bnb/test_4bit.py | 8 +++++--- tests/quantization/bnb/test_mixed_int8.py | 6 ++++++ tests/quantization/test_torch_compile_utils.py | 17 +++++++++++++++++ 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index e161f8e4a132..244c1d39d49a 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -879,6 +879,8 @@ def test_torch_compile(self): def test_torch_compile_with_cpu_offload(self): torch._dynamo.config.capture_dynamic_output_shape_ops = True - super()._test_torch_compile_with_cpu_offload( - quantization_config=self.quantization_config, torch_dtype=torch.float16 - ) + super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) + + def test_torch_compile_with_group_offload(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index a8f1d98610fe..87a98e1d395d 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -796,3 +796,9 @@ def test_torch_compile_with_cpu_offload(self): super()._test_torch_compile_with_cpu_offload( quantization_config=self.quantization_config, torch_dtype=torch.float16 ) + + def test_torch_compile_with_group_offload(self): + torch._dynamo.config.capture_dynamic_output_shape_ops = True + super()._test_torch_compile_with_group_offload( + quantization_config=self.quantization_config, torch_dtype=torch.float16 + ) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index ab764b14bed7..9940d8723895 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -63,3 +63,20 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype= for _ in range(2): # small resolutions to ensure speedy execution. pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) + + def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): + pipe = self._init_pipeline(quantization_config, torch_dtype) + group_offload_kwargs = { + "onload_device": "cuda", + "offload_device": "cpu", + "offload_type": "block_level", + "num_blocks_per_group": 1, + "use_stream": True, + "non_blocking": True, + } + pipe.enable_group_offload(**group_offload_kwargs) + pipe.transformer.compile() + + for _ in range(2): + # small resolutions to ensure speedy execution. 
+ pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) From d3010dd3859f026d204721073b69accdf5c61428 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 10:55:21 +0530 Subject: [PATCH 6/9] xfail --- tests/quantization/bnb/test_4bit.py | 5 +++-- tests/quantization/bnb/test_mixed_int8.py | 3 +-- tests/quantization/test_torch_compile_utils.py | 12 +++++++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 244c1d39d49a..3b4786f53758 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -878,9 +878,10 @@ def test_torch_compile(self): super()._test_torch_compile(quantization_config=self.quantization_config) def test_torch_compile_with_cpu_offload(self): - torch._dynamo.config.capture_dynamic_output_shape_ops = True super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) + @pytest.mark.xfail( + reason="Test fails because of an illegal memory access.", + ) def test_torch_compile_with_group_offload(self): - torch._dynamo.config.capture_dynamic_output_shape_ops = True super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 87a98e1d395d..592e3cc9422b 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -792,13 +792,12 @@ def test_torch_compile(self): super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16) def test_torch_compile_with_cpu_offload(self): - torch._dynamo.config.capture_dynamic_output_shape_ops = True super()._test_torch_compile_with_cpu_offload( quantization_config=self.quantization_config, torch_dtype=torch.float16 ) + @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.") def test_torch_compile_with_group_offload(self): - torch._dynamo.config.capture_dynamic_output_shape_ops = True super()._test_torch_compile_with_group_offload( quantization_config=self.quantization_config, torch_dtype=torch.float16 ) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index 9940d8723895..5859b8962008 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -65,17 +65,23 @@ def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype= pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256) def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16): + torch._dynamo.config.cache_size_limit = 10000 + pipe = self._init_pipeline(quantization_config, torch_dtype) group_offload_kwargs = { - "onload_device": "cuda", - "offload_device": "cpu", + "onload_device": torch.device("cuda"), + "offload_device": torch.device("cpu"), "offload_type": "block_level", "num_blocks_per_group": 1, "use_stream": True, "non_blocking": True, } - pipe.enable_group_offload(**group_offload_kwargs) + pipe.transformer.enable_group_offload(**group_offload_kwargs) pipe.transformer.compile() + for name, component in pipe.components.items(): + if name != "transformer" and isinstance(component, torch.nn.Module): + if torch.device(component.device).type == "cpu": + component.to("cuda") for _ in range(2): # small resolutions to ensure speedy 
execution. From af5707004ce74e319793e699e0979373d5e5eff9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 7 Jun 2025 11:29:15 +0530 Subject: [PATCH 7/9] update --- tests/quantization/bnb/test_4bit.py | 3 --- tests/quantization/test_torch_compile_utils.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 3b4786f53758..c16e0e611671 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -880,8 +880,5 @@ def test_torch_compile(self): def test_torch_compile_with_cpu_offload(self): super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config) - @pytest.mark.xfail( - reason="Test fails because of an illegal memory access.", - ) def test_torch_compile_with_group_offload(self): super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index 5859b8962008..1acedb374052 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -71,8 +71,7 @@ def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtyp group_offload_kwargs = { "onload_device": torch.device("cuda"), "offload_device": torch.device("cpu"), - "offload_type": "block_level", - "num_blocks_per_group": 1, + "offload_type": "leaf_level", "use_stream": True, "non_blocking": True, } From 6f5df29e384699cb5c17315411669b92b5bed89a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 11 Jun 2025 18:52:13 +0530 Subject: [PATCH 8/9] Update tests/quantization/test_torch_compile_utils.py Co-authored-by: Dhruv Nair --- tests/quantization/test_torch_compile_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py index 1acedb374052..1ae77b27d7cd 100644 --- a/tests/quantization/test_torch_compile_utils.py +++ b/tests/quantization/test_torch_compile_utils.py @@ -23,7 +23,7 @@ @require_torch_gpu @slow -class QuantCompileMiscTests(unittest.TestCase): +class QuantCompileTests(unittest.TestCase): quantization_config = None def setUp(self): From d44a29d53a5500849a2f93f230507d120a371c2c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 11 Jun 2025 20:41:04 +0530 Subject: [PATCH 9/9] fixes --- tests/quantization/bnb/test_4bit.py | 4 ++-- tests/quantization/bnb/test_mixed_int8.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index c16e0e611671..2d8b9f698bfe 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -51,7 +51,7 @@ torch_device, ) -from ..test_torch_compile_utils import QuantCompileMiscTests +from ..test_torch_compile_utils import QuantCompileTests def get_some_linear_layer(model): @@ -862,7 +862,7 @@ def test_fp4_double_safe(self): @require_torch_version_greater("2.7.1") -class Bnb4BitCompileTests(QuantCompileMiscTests): +class Bnb4BitCompileTests(QuantCompileTests): quantization_config = PipelineQuantizationConfig( quant_backend="bitsandbytes_8bit", quant_kwargs={ diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py index 592e3cc9422b..f67a49e7bffe 100644 --- a/tests/quantization/bnb/test_mixed_int8.py +++ b/tests/quantization/bnb/test_mixed_int8.py @@ -49,7 +49,7 @@ 
torch_device, ) -from ..test_torch_compile_utils import QuantCompileMiscTests +from ..test_torch_compile_utils import QuantCompileTests def get_some_linear_layer(model): @@ -780,7 +780,7 @@ def test_serialization_sharded(self): @require_torch_version_greater_equal("2.6.0") -class Bnb8BitCompileTests(QuantCompileMiscTests): +class Bnb8BitCompileTests(QuantCompileTests): quantization_config = PipelineQuantizationConfig( quant_backend="bitsandbytes_8bit", quant_kwargs={"load_in_8bit": True},
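
Editor's note on the resulting test structure: by the end of the series, `QuantCompileTests` (in `tests/quantization/test_torch_compile_utils.py`) owns pipeline construction and the three compile scenarios — plain `torch.compile`, model CPU offload, and group offload — while each backend-specific class only supplies a `PipelineQuantizationConfig` plus thin `test_*` wrappers that call the `_test_*` helpers. As a rough sketch of how another quantization backend could plug into the same harness — the `torchao` backend name, its `quant_kwargs`, and the file placement below are illustrative assumptions, not part of this patch series — a subclass might look like:

```python
# Hypothetical sketch (not part of this PR): a torchao-backed subclass of the
# shared compile-test harness, assumed to live under tests/quantization/.
import torch

from diffusers.quantizers import PipelineQuantizationConfig

from ..test_torch_compile_utils import QuantCompileTests


class TorchAoCompileTests(QuantCompileTests):
    # Assumed backend name and kwargs for illustration; substitute whatever the
    # backend actually expects.
    quantization_config = PipelineQuantizationConfig(
        quant_backend="torchao",
        quant_kwargs={"quant_type": "int8wo"},
        components_to_quantize=["transformer"],
    )

    def test_torch_compile(self):
        # Plain torch.compile(fullgraph=True) on the quantized transformer.
        super()._test_torch_compile(quantization_config=self.quantization_config)

    def test_torch_compile_with_cpu_offload(self):
        # Same run, but with enable_model_cpu_offload() active.
        super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)

    def test_torch_compile_with_group_offload(self):
        # Same run, but with leaf-level group offloading on the transformer.
        super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
```

Since the base class is decorated with `@require_torch_gpu` and `@slow`, such tests are skipped by default and only run on a CUDA machine with slow tests enabled (e.g. `RUN_SLOW=yes`).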