# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

import torch
import torch.nn as nn
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device


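# Builds one LoRAAttnProcessor per attention processor in the UNet. The hidden
# size of each processor is inferred from the channel width of the block that
# owns it (mid/up/down); attn1 processors are self-attention, so they get no
# cross-attention dimension.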
def create_unet_lora_layers(unet: nn.Module):
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        else:
            # Guard against `hidden_size` being unbound if an unexpected
            # processor name ever shows up.
            raise ValueError(f"Unexpected attention processor name: {name}")
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
    return lora_attn_procs, unet_lora_layers


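# Attaches a LoRAAttnProcessor to every text encoder submodule whose name
# matches one of TEXT_ENCODER_TARGET_MODULES (the attention projection layers).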
def create_text_encoder_lora_layers(text_encoder: nn.Module):
    text_lora_attn_procs = {}
    for name, module in text_encoder.named_modules():
        if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None)
    text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
    return text_encoder_lora_layers


class LoraLoaderMixinTests(unittest.TestCase):
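    """Round-trip tests for LoraLoaderMixin: LoRA weights saved in the default
    .bin format, in safetensors, and via the legacy `save_attn_procs` path
    should all load back and change the pipeline's output."""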
    def get_dummy_components(self):
        # Tiny UNet/VAE/CLIP configurations keep the test lightweight.
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
        scheduler = DDIMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
        )
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

        # LoRA layers for both the UNet and the text encoder.
        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
        text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder)

        pipeline_components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "safety_checker": None,
            "feature_extractor": None,
        }
        lora_components = {
            "unet_lora_layers": unet_lora_layers,
            "text_encoder_lora_layers": text_encoder_lora_layers,
            "unet_lora_attn_procs": unet_lora_attn_procs,
        }
        return pipeline_components, lora_components

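    # Returns latent noise and token ids (unused by the pipeline-level tests
    # below) along with the keyword arguments for a full pipeline run.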
    def get_dummy_inputs(self):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "numpy",
        }

        return noise, input_ids, pipeline_inputs

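    # Saving LoRA weights in the default .bin format and loading them back
    # should change the pipeline's output.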
    def test_lora_save_load(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            LoraLoaderMixin.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

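    # Same round trip, but serialized with safetensors.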
    def test_lora_save_load_safetensors(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            LoraLoaderMixin.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
                safe_serialization=True,
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

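    # The legacy path saves LoRA weights through the UNet's `save_attn_procs`;
    # `load_lora_weights` should still be able to read that format.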
    def test_lora_save_load_legacy(self):
        pipeline_components, lora_components = self.get_dummy_components()
        unet_lora_attn_procs = lora_components["unet_lora_attn_procs"]
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        noise, input_ids, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            unet = sd_pipe.unet
            unet.set_attn_processor(unet_lora_attn_procs)
            unet.save_attn_procs(tmpdirname)
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
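

# Not part of the original diff (this suite is normally driven by pytest);
# added so the module can also be executed directly.
if __name__ == "__main__":
    unittest.main()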