Commit 5b6582c

[Model offload] Add nice warning (#2543)
* [Model offload] Add nice warning
* Treat sequential and model offload differently. Sequential raises an error because the operation would fail with a cryptic warning later.
* Forcibly move to cpu when offloading.
* make style
* one more fix
* make fix-copies
* up

---------

Co-authored-by: Pedro Cuenca <[email protected]>
1 parent 4f0141a · commit 5b6582c

15 files changed: +166, -2 lines changed
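
In user-facing terms, the two offload modes now react differently when `.to("cuda")` is called afterwards. A minimal sketch of the behaviour this commit introduces (the checkpoint name and dtype are illustrative, not taken from the diff):

```python
import torch
from diffusers import StableDiffusionPipeline

# Model offload: a later .to("cuda") still succeeds, but now logs a warning that
# the memory gains from offloading are likely to be lost.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.to("cuda")  # emits the new warning

# Sequential offload: a later .to("cuda") is incompatible, so it now raises a clear
# ValueError instead of failing later with a cryptic error.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_sequential_cpu_offload()
pipe.to("cuda")  # raises ValueError
```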

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 8 additions & 0 deletions
@@ -214,6 +214,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -234,6 +238,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
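
The `torch.cuda.empty_cache()` call is there because PyTorch's caching allocator keeps freed blocks reserved on the GPU, so the savings from moving modules back to CPU would otherwise not be visible in tools like `nvidia-smi`. A minimal standalone sketch of that effect (tensor size is arbitrary; requires a CUDA device):

```python
import torch

x = torch.randn(256, 1024, 1024, device="cuda")  # ~1 GiB of float32 data
x = x.to("cpu")  # the CUDA copy is freed, but its memory stays in the allocator cache

print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")  # ~0
print(torch.cuda.memory_reserved() // 2**20, "MiB reserved")    # still ~1024

torch.cuda.empty_cache()  # hand the cached blocks back to the driver
print(torch.cuda.memory_reserved() // 2**20, "MiB reserved")    # now ~0
```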

src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 8 additions & 0 deletions
@@ -220,6 +220,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -240,6 +244,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 42 additions & 2 deletions
@@ -49,6 +49,7 @@
     get_class_from_dynamic_module,
     http_user_agent,
     is_accelerate_available,
+    is_accelerate_version,
     is_safetensors_available,
     is_torch_version,
     is_transformers_available,
@@ -66,6 +67,10 @@
 from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
 
 
+if is_accelerate_available():
+    import accelerate
+
+
 INDEX_FILE = "diffusion_pytorch_model.bin"
 CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
 DUMMY_MODULES_FOLDER = "diffusers.utils"
@@ -337,15 +342,50 @@ def is_saveable_module(name, value):
 
         save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None):
+    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
         if torch_device is None:
             return self
 
+        # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+        def module_is_sequentially_offloaded(module):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+                return False
+
+            return hasattr(module, "_hf_hook") and not isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+        def module_is_offloaded(module):
+            if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
+                return False
+
+            return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+        # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
+        pipeline_is_sequentially_offloaded = any(
+            module_is_sequentially_offloaded(module) for _, module in self.components.items()
+        )
+        if pipeline_is_sequentially_offloaded and torch.device(torch_device).type == "cuda":
+            raise ValueError(
+                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+            )
+
+        # Display a warning in this case (the operation succeeds but the benefits are lost)
+        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
+        if pipeline_is_offloaded and torch.device(torch_device).type == "cuda":
+            logger.warning(
+                f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
+            )
+
         module_names, _, _ = self.extract_init_dict(dict(self.config))
+        is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
-                if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
+                if (
+                    module.dtype == torch.float16
+                    and str(torch_device) in ["cpu"]
+                    and not silence_dtype_warnings
+                    and not is_offloaded
+                ):
                     logger.warning(
                         "Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
                         " is not recommended to move them to `cpu` as running them will fail. Please make"

src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py

Lines changed: 8 additions & 0 deletions
@@ -237,6 +237,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -258,6 +262,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 8 additions & 0 deletions
@@ -217,6 +217,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py

Lines changed: 4 additions & 0 deletions
@@ -263,6 +263,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 8 additions & 0 deletions
@@ -225,6 +225,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -246,6 +250,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 8 additions & 0 deletions
@@ -272,6 +272,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -293,6 +297,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 8 additions & 0 deletions
@@ -216,6 +216,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -237,6 +241,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 8 additions & 0 deletions
@@ -405,6 +405,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -426,6 +430,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py

Lines changed: 8 additions & 0 deletions
@@ -137,6 +137,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
@@ -158,6 +162,10 @@ def enable_model_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
             _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py

Lines changed: 4 additions & 0 deletions
@@ -158,6 +158,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py

Lines changed: 4 additions & 0 deletions
@@ -372,6 +372,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py

Lines changed: 4 additions & 0 deletions
@@ -176,6 +176,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
 
         device = torch.device(f"cuda:{gpu_id}")
 
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 

tests/test_pipelines.py

Lines changed: 36 additions & 0 deletions
@@ -584,6 +584,42 @@ def test_stable_diffusion_components(self):
         assert image_img2img.shape == (1, 32, 32, 3)
         assert image_text2img.shape == (1, 64, 64, 3)
 
+    @require_torch_gpu
+    def test_pipe_false_offload_warn(self):
+        unet = self.dummy_cond_unet()
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        sd.enable_model_cpu_offload()
+
+        logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
+        with CaptureLogger(logger) as cap_logger:
+            sd.to("cuda")
+
+        assert "It is strongly recommended against doing so" in str(cap_logger)
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
     def test_set_scheduler(self):
         unet = self.dummy_cond_unet()
         scheduler = PNDMScheduler(skip_prk_steps=True)
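
The new test only exercises the warning path for model offloading. A hedged sketch of how the error path for sequential offloading could be tested in the same spirit (the test name, checkpoint id, and skip condition are illustrative and not part of this commit):

```python
import pytest
import torch
from diffusers import StableDiffusionPipeline


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
def test_sequential_offload_then_to_cuda_raises():
    pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
    pipe.enable_sequential_cpu_offload()

    # After sequential offloading, moving the whole pipeline to GPU is rejected
    # with the explicit ValueError added to DiffusionPipeline.to() in this commit.
    with pytest.raises(ValueError):
        pipe.to("cuda")
```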
