diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 7696852ecd44..9af0bb6d2280 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -525,6 +525,60 @@ jobs: pip install slack_sdk tabulate python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + run_nightly_pipeline_level_quantization_tests: + name: Torch quantization nightly tests + strategy: + fail-fast: false + max-parallel: 2 + runs-on: + group: aws-g6e-xlarge-plus + container: + image: diffusers/diffusers-pytorch-cuda + options: --shm-size "20gb" --ipc host --gpus 0 + steps: + - name: Checkout diffusers + uses: actions/checkout@v3 + with: + fetch-depth: 2 + - name: NVIDIA-SMI + run: nvidia-smi + - name: Install dependencies + run: | + python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH" + python -m uv pip install -e [quality,test] + python -m uv pip install -U bitsandbytes optimum_quanto + python -m uv pip install pytest-reportlog + - name: Environment + run: | + python utils/print_env.py + - name: Pipeline-level quantization tests on GPU + env: + HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }} + # https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms + CUBLAS_WORKSPACE_CONFIG: :16:8 + BIG_GPU_MEMORY: 40 + run: | + python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ + --make-reports=tests_pipeline_level_quant_torch_cuda \ + --report-log=tests_pipeline_level_quant_torch_cuda.log \ + tests/quantization/test_pipeline_level_quantization.py + - name: Failure short reports + if: ${{ failure() }} + run: | + cat reports/tests_pipeline_level_quant_torch_cuda_stats.txt + cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt + - name: Test suite reports artifacts + if: ${{ always() }} + uses: actions/upload-artifact@v4 + with: + name: torch_cuda_pipeline_level_quant_reports + path: reports + - name: Generate Report and Notify Channel + if: always() + run: | + pip install slack_sdk tabulate + python utils/log_reports.py >> $GITHUB_STEP_SUMMARY + # M1 runner currently not well supported # TODO: (Dhruv) add these back when we setup better testing for Apple Silicon # run_nightly_tests_apple_m1: diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md index 2c728cff3c07..e2ca990190e6 100644 --- a/docs/source/en/api/quantization.md +++ b/docs/source/en/api/quantization.md @@ -13,9 +13,7 @@ specific language governing permissions and limitations under the License. # Quantization -Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. Diffusers supports 8-bit and 4-bit quantization with [bitsandbytes](https://huggingface.co/docs/bitsandbytes/en/index). - -Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class. +Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This enables loading larger models you normally wouldn't be able to fit into memory, and speeding up inference. @@ -23,6 +21,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui +## PipelineQuantizationConfig + +[[autodoc]] quantizers.PipelineQuantizationConfig ## BitsAndBytesConfig diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md index 93323f86c7fc..68b99f524ec0 100644 --- a/docs/source/en/quantization/overview.md +++ b/docs/source/en/quantization/overview.md @@ -39,3 +39,90 @@ Diffusers currently supports the following quantization methods. - [Quanto](./quanto.md) [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques. + +## Pipeline-level quantization + +Diffusers allows users to directly initialize pipelines from checkpoints that may contain quantized models ([example](https://huggingface.co/hf-internal-testing/flux.1-dev-nf4-pkg)). However, users may want to apply +quantization on-the-fly when initializing a pipeline from a pre-trained and non-quantized checkpoint. You can +do this with [`~quantizers.PipelineQuantizationConfig`]. + +Start by defining a `PipelineQuantizationConfig`: + +```py +import torch +from diffusers import DiffusionPipeline +from diffusers.quantizers.quantization_config import QuantoConfig +from diffusers.quantizers import PipelineQuantizationConfig +from transformers import BitsAndBytesConfig + +pipeline_quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": BitsAndBytesConfig( + load_in_4bit=True, compute_dtype=torch.bfloat16 + ), + } +) +``` + +Then pass it to [`~DiffusionPipeline.from_pretrained`] and run inference: + +```py +pipe = DiffusionPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + quantization_config=pipeline_quant_config, + torch_dtype=torch.bfloat16, +).to("cuda") + +image = pipe("photo of a cute dog").images[0] +``` + +This method allows for more granular control over the quantization specifications of individual +model-level components of a pipeline. It also allows for different quantization backends for +different components. In the above example, you used a combination of Quanto and BitsandBytes. However, +one caveat of this method is that users need to know which components come from `transformers` to be able +to import the right quantization config class. + +The other method is simpler in terms of experience but is +less-flexible. Start by defining a `PipelineQuantizationConfig` but in a different way: + +```py +pipeline_quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16}, + components_to_quantize=["transformer", "text_encoder_2"], +) +``` + +This `pipeline_quant_config` can now be passed to [`~DiffusionPipeline.from_pretrained`] similar to the above example. + +In this case, `quant_kwargs` will be used to initialize the quantization specifications +of the respective quantization configuration class of `quant_backend`. `components_to_quantize` +is used to denote the components that will be quantized. For most pipelines, you would want to +keep `transformer` in the list as that is often the most compute and memory intensive. + +The config below will work for most diffusion pipelines that have a `transformer` component present. +In most case, you will want to quantize the `transformer` component as that is often the most compute- +intensive part of a diffusion pipeline. + +```py +pipeline_quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16}, + components_to_quantize=["transformer"], +) +``` + +Below is a list of the supported quantization backends available in both `diffusers` and `transformers`: + +* `bitsandbytes_4bit` +* `bitsandbytes_8bit` +* `gguf` +* `quanto` +* `torchao` + + +Diffusion pipelines can have multiple text encoders. [`FluxPipeline`] has two, for example. It's +recommended to quantize the text encoders that are memory-intensive. Some examples include T5, +Llama, Gemma, etc. In the above example, you quantized the T5 model of [`FluxPipeline`] through +`text_encoder_2` while keeping the CLIP model intact (accessible through `text_encoder`). \ No newline at end of file diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 9788d758e9bc..3404ae5130fe 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -675,8 +675,10 @@ def load_sub_model( use_safetensors: bool, dduf_entries: Optional[Dict[str, DDUFEntry]], provider_options: Any, + quantization_config: Optional[Any] = None, ): """Helper method to load the module `name` from `library_name` and `class_name`""" + from ..quantizers import PipelineQuantizationConfig # retrieve class candidates @@ -769,6 +771,17 @@ def load_sub_model( else: loading_kwargs["low_cpu_mem_usage"] = False + if ( + quantization_config is not None + and isinstance(quantization_config, PipelineQuantizationConfig) + and issubclass(class_obj, torch.nn.Module) + ): + model_quant_config = quantization_config._resolve_quant_config( + is_diffusers=is_diffusers_model, module_name=name + ) + if model_quant_config is not None: + loading_kwargs["quantization_config"] = model_quant_config + # check if the module is in a subdirectory if dduf_entries: loading_kwargs["dduf_entries"] = dduf_entries diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 3be3e46ca44c..7cb2a12d3c94 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -47,6 +47,7 @@ from ..models import AutoencoderKL from ..models.attention_processor import FusedAttnProcessor2_0 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin +from ..quantizers import PipelineQuantizationConfig from ..quantizers.bitsandbytes.utils import _check_bnb_status from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from ..utils import ( @@ -725,6 +726,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) + quantization_config = kwargs.pop("quantization_config", None) if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype): torch_dtype = torch.float32 @@ -741,6 +743,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " install accelerate\n```\n." ) + if quantization_config is not None and not isinstance(quantization_config, PipelineQuantizationConfig): + raise ValueError("`quantization_config` must be an instance of `PipelineQuantizationConfig`.") + if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): raise NotImplementedError( "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" @@ -1001,6 +1006,7 @@ def load_module(name, value): use_safetensors=use_safetensors, dduf_entries=dduf_entries, provider_options=provider_options, + quantization_config=quantization_config, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py index 4c8483a3d6ee..bd9e2303c93b 100644 --- a/src/diffusers/quantizers/__init__.py +++ b/src/diffusers/quantizers/__init__.py @@ -12,5 +12,183 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect +from typing import Dict, List, Optional, Union + +from ..utils import is_transformers_available, logging from .auto import DiffusersAutoQuantizer from .base import DiffusersQuantizer +from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin + + +try: + from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin +except ImportError: + + class TransformersQuantConfigMixin: + pass + + +logger = logging.get_logger(__name__) + + +class PipelineQuantizationConfig: + """ + Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`]. + + Args: + quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend + is available to both `diffusers` and `transformers`. + quant_kwargs (`dict`): Params to initialize the quantization backend class. + components_to_quantize (`list`): Components of a pipeline to be quantized. + quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline + components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`, + and `components_to_quantize`. + """ + + def __init__( + self, + quant_backend: str = None, + quant_kwargs: Dict[str, Union[str, float, int, dict]] = None, + components_to_quantize: Optional[List[str]] = None, + quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None, + ): + self.quant_backend = quant_backend + # Initialize kwargs to be {} to set to the defaults. + self.quant_kwargs = quant_kwargs or {} + self.components_to_quantize = components_to_quantize + self.quant_mapping = quant_mapping + + self.post_init() + + def post_init(self): + quant_mapping = self.quant_mapping + self.is_granular = True if quant_mapping is not None else False + + self._validate_init_args() + + def _validate_init_args(self): + if self.quant_backend and self.quant_mapping: + raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.") + + if not self.quant_mapping and not self.quant_backend: + raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.") + + if not self.quant_kwargs and not self.quant_mapping: + raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.") + + if self.quant_backend is not None: + self._validate_init_kwargs_in_backends() + + if self.quant_mapping is not None: + self._validate_quant_mapping_args() + + def _validate_init_kwargs_in_backends(self): + quant_backend = self.quant_backend + + self._check_backend_availability(quant_backend) + + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + if quant_config_mapping_transformers is not None: + init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__) + init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"} + else: + init_kwargs_transformers = None + + init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__) + init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"} + + if init_kwargs_transformers != init_kwargs_diffusers: + raise ValueError( + "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. " + f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how " + "this mapping would look like." + ) + + def _validate_quant_mapping_args(self): + quant_mapping = self.quant_mapping + transformers_map, diffusers_map = self._get_quant_config_list() + + available_transformers = list(transformers_map.values()) if transformers_map else None + available_diffusers = list(diffusers_map.values()) + + for module_name, config in quant_mapping.items(): + if any(isinstance(config, cfg) for cfg in available_diffusers): + continue + + if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers): + continue + + if available_transformers: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}; " + f"Available transformers configs: {available_transformers}." + ) + else: + raise ValueError( + f"Provided config for module_name={module_name} could not be found. " + f"Available diffusers configs: {available_diffusers}." + ) + + def _check_backend_availability(self, quant_backend: str): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + available_backends_transformers = ( + list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None + ) + available_backends_diffusers = list(quant_config_mapping_diffusers.keys()) + + if ( + available_backends_transformers and quant_backend not in available_backends_transformers + ) or quant_backend not in quant_config_mapping_diffusers: + error_message = f"Provided quant_backend={quant_backend} was not found." + if available_backends_transformers: + error_message += f"\nAvailable ones (transformers): {available_backends_transformers}." + error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}." + raise ValueError(error_message) + + def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None): + quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list() + + quant_mapping = self.quant_mapping + components_to_quantize = self.components_to_quantize + + # Granular case + if self.is_granular and module_name in quant_mapping: + logger.debug(f"Initializing quantization config class for {module_name}.") + config = quant_mapping[module_name] + return config + + # Global config case + else: + should_quantize = False + # Only quantize the modules requested for. + if components_to_quantize and module_name in components_to_quantize: + should_quantize = True + # No specification for `components_to_quantize` means all modules should be quantized. + elif not self.is_granular and not components_to_quantize: + should_quantize = True + + if should_quantize: + logger.debug(f"Initializing quantization config class for {module_name}.") + mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers + quant_config_cls = mapping_to_use[self.quant_backend] + quant_kwargs = self.quant_kwargs + return quant_config_cls(**quant_kwargs) + + # Fallback: no applicable configuration found. + return None + + def _get_quant_config_list(self): + if is_transformers_available(): + from transformers.quantizers.auto import ( + AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers, + ) + else: + quant_config_mapping_transformers = None + + from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers + + return quant_config_mapping_transformers, quant_config_mapping_diffusers diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index cc4d4fc1b017..7fec68642f01 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -75,7 +75,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): Args: config_dict (`Dict[str, Any]`): Dictionary that will be used to instantiate the configuration object. - return_unused_kwargs (`bool`,*optional*, defaults to `False`): + return_unused_kwargs (`bool`, *optional*, defaults to `False`): Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in `PreTrainedModel`. kwargs (`Dict[str, Any]`): diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7a524e76f16e..00aad9d71a61 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -38,6 +38,7 @@ is_note_seq_available, is_onnx_available, is_opencv_available, + is_optimum_quanto_available, is_peft_available, is_timm_available, is_torch_available, @@ -486,6 +487,13 @@ def require_bitsandbytes(test_case): return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) +def require_quanto(test_case): + """ + Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed. + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case) + + def require_accelerate(test_case): """ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py new file mode 100644 index 000000000000..b82b2889d72d --- /dev/null +++ b/tests/quantization/test_pipeline_level_quantization.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Team Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a clone of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +import torch + +from diffusers import DiffusionPipeline, QuantoConfig +from diffusers.quantizers import PipelineQuantizationConfig +from diffusers.utils.testing_utils import ( + is_transformers_available, + require_accelerate, + require_bitsandbytes_version_greater, + require_quanto, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) + + +if is_transformers_available(): + from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig +else: + TranBitsAndBytesConfig = None + + +@require_bitsandbytes_version_greater("0.43.2") +@require_quanto +@require_accelerate +@require_torch +@require_torch_accelerator +@slow +class PipelineQuantizationTests(unittest.TestCase): + model_name = "hf-internal-testing/tiny-flux-pipe" + prompt = "a beautiful sunset amidst the mountains." + num_inference_steps = 10 + seed = 0 + + def test_quant_config_set_correctly_through_kwargs(self): + components_to_quantize = ["transformer", "text_encoder_2"] + quant_config = PipelineQuantizationConfig( + quant_backend="bitsandbytes_4bit", + quant_kwargs={ + "load_in_4bit": True, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": torch.bfloat16, + }, + components_to_quantize=components_to_quantize, + ) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + for name, component in pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + + _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps) + + def test_quant_config_set_correctly_through_granular(self): + quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + components_to_quantize = list(quant_config.quant_mapping.keys()) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + for name, component in pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + + if name == "text_encoder_2": + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + else: + self.assertTrue(quantization_config.quant_method == "quanto") + + _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps) + + def test_raises_error_for_invalid_config(self): + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + }, + quant_backend="bitsandbytes_4bit", + ) + + self.assertTrue( + str(err_context.exception) + == "Both `quant_backend` and `quant_mapping` cannot be specified at the same time." + ) + + def test_validation_for_kwargs(self): + components_to_quantize = ["transformer", "text_encoder_2"] + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_backend="quanto", + quant_kwargs={"weights_dtype": "int8"}, + components_to_quantize=components_to_quantize, + ) + + self.assertTrue( + "The signatures of the __init__ methods of the quantization config classes" in str(err_context.exception) + ) + + def test_raises_error_for_wrong_config_class(self): + quant_config = { + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + with self.assertRaises(ValueError) as err_context: + _ = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ) + self.assertTrue( + str(err_context.exception) == "`quantization_config` must be an instance of `PipelineQuantizationConfig`." + ) + + def test_validation_for_mapping(self): + with self.assertRaises(ValueError) as err_context: + _ = PipelineQuantizationConfig( + quant_mapping={ + "transformer": DiffusionPipeline(), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + + self.assertTrue("Provided config for module_name=transformer could not be found" in str(err_context.exception)) + + def test_saving_loading(self): + quant_config = PipelineQuantizationConfig( + quant_mapping={ + "transformer": QuantoConfig(weights_dtype="int8"), + "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16), + } + ) + components_to_quantize = list(quant_config.quant_mapping.keys()) + pipe = DiffusionPipeline.from_pretrained( + self.model_name, + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + ).to(torch_device) + + pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"} + output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device) + for name, component in loaded_pipe.components.items(): + if name in components_to_quantize: + self.assertTrue(getattr(component.config, "quantization_config", None) is not None) + quantization_config = component.config.quantization_config + + if name == "text_encoder_2": + self.assertTrue(quantization_config.load_in_4bit) + self.assertTrue(quantization_config.quant_method == "bitsandbytes") + else: + self.assertTrue(quantization_config.quant_method == "quanto") + + output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images + + self.assertTrue(torch.allclose(output_1, output_2))