
Commit 34d14d7

[MultiControlNet] Allow save and load (#3747)
* [MultiControlNet] Allow save and load
* Correct more
* [MultiControlNet] Allow save and load
* make style
* Apply suggestions from code review
1 parent ef95907 commit 34d14d7

7 files changed: +123 additions, −88 deletions

src/diffusers/pipelines/controlnet/multicontrolnet.py

Lines changed: 120 additions & 1 deletion

@@ -1,10 +1,15 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 
 from ...models.controlnet import ControlNetModel, ControlNetOutput
 from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
 
 
 class MultiControlNetModel(ModelMixin):
@@ -64,3 +69,117 @@ def forward(
                 mid_block_res_sample += mid_sample
 
         return down_block_res_samples, mid_block_res_sample
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        safe_serialization: bool = False,
+        variant: Optional[str] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory so that it can be reloaded using the
+        [`~pipelines.controlnet.MultiControlNetModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful in distributed training (e.g. on
+                TPUs) when this function needs to be called on all processes. In that case, set `is_main_process=True`
+                only on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful in distributed training (e.g. on TPUs) when
+                `torch.save` needs to be replaced by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `False`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (which uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+        """
+        idx = 0
+        model_path_to_save = save_directory
+        for controlnet in self.nets:
+            controlnet.save_pretrained(
+                model_path_to_save,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+            idx += 1
+            model_path_to_save = model_path_to_save + f"_{idx}"
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained MultiControlNet model from multiple pre-trained ControlNet models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To
+        train the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g.,
+                `./my_model_directory/controlnet`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed, the
+                dtype will be automatically derived from the model's weights.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
+                available for each GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x the model size in CPU memory (including peak memory) while loading
+                the model. This is only supported when torch version >= 1.9.0. If you are using an older version of
+                torch, setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified, load weights from a `variant` filename, *e.g.* `pytorch_model.<variant>.bin`. `variant`
+                is ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        controlnets = []
+
+        # load controlnets and append them to the list until no controlnet directory exists anymore
+        # the first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
+        # the second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
+            controlnets.append(controlnet)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+        if len(controlnets) == 0:
+            raise ValueError(
+                f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(controlnets)
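In plain terms, `save_pretrained` writes the first ControlNet to `save_directory` and each additional one to a sibling directory suffixed `_1`, `_2`, ..., while `from_pretrained` walks those directories back into a list and wraps it in a new `MultiControlNetModel`. A minimal round-trip sketch (the Hub checkpoint ids and the local path are illustrative assumptions, not part of this commit):

```python
# Sketch of the save/load round trip added in this commit.
# The two Hub checkpoints and the local path below are illustrative assumptions.
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel

canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
openpose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")
multi = MultiControlNetModel([canny, openpose])

# Writes ./my_model_directory/controlnet for the first net
# and ./my_model_directory/controlnet_1 for the second.
multi.save_pretrained("./my_model_directory/controlnet")

# Walks ./my_model_directory/controlnet, ./my_model_directory/controlnet_1, ...
# until a directory is missing, then rebuilds the MultiControlNetModel.
reloaded = MultiControlNetModel.from_pretrained("./my_model_directory/controlnet")
assert len(reloaded.nets) == 2
```

The `controlnet`/`controlnet_1`/... naming matters: it is what lets `DiffusionPipeline.from_pretrained` locate the sub-models when a whole pipeline is saved, as the pipeline changes below rely on.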

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 1 addition & 14 deletions

@@ -14,7 +14,6 @@
 
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -560,7 +559,7 @@ def check_inputs(
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:
@@ -679,18 +678,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
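With the `save_pretrained` override removed, `StableDiffusionControlNetPipeline` falls back to `DiffusionPipeline.save_pretrained`, which can now also serialize a `MultiControlNetModel` component. A minimal sketch of the pipeline-level round trip (checkpoint ids and the output path are again illustrative assumptions, not taken from this commit):

```python
# Sketch: saving and reloading a ControlNet pipeline that wraps two ControlNets.
# Checkpoint ids and the output path are illustrative assumptions.
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnets = [
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose"),
]
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnets
)

# Before this commit, this call raised NotImplementedError for Multi-ControlNet;
# now it delegates to DiffusionPipeline.save_pretrained, which in turn calls
# MultiControlNetModel.save_pretrained on the `controlnet` component.
pipe.save_pretrained("./sd15-multicontrolnet")

reloaded = StableDiffusionControlNetPipeline.from_pretrained("./sd15-multicontrolnet")
assert len(reloaded.controlnet.nets) == 2
```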

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 1 addition & 14 deletions

@@ -14,7 +14,6 @@
 
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -586,7 +585,7 @@ def check_inputs(
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:
@@ -757,18 +756,6 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         return latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 1 addition & 14 deletions

@@ -15,7 +15,6 @@
 # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
 
 import inspect
-import os
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -718,7 +717,7 @@ def check_inputs(
                 raise ValueError("A single batch of multiple conditionings are supported at the moment.")
             elif len(image) != len(self.controlnet.nets):
                 raise ValueError(
-                    "For multiple controlnets: `image` must have the same length as the number of controlnets."
+                    f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
                 )
 
             for image_ in image:
@@ -957,18 +956,6 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         return image_latents
 
-    # override DiffusionPipeline
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
-        variant: Optional[str] = None,
-    ):
-        if isinstance(self.controlnet, ControlNetModel):
-            super().save_pretrained(save_directory, safe_serialization, variant)
-        else:
-            raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

tests/pipelines/controlnet/test_controlnet.py

Lines changed: 0 additions & 15 deletions

@@ -346,21 +346,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_img2img.py

Lines changed: 0 additions & 15 deletions

@@ -304,21 +304,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu

tests/pipelines/controlnet/test_controlnet_inpaint.py

Lines changed: 0 additions & 15 deletions

@@ -382,21 +382,6 @@ def test_save_pretrained_raise_not_implemented_exception(self):
         except NotImplementedError:
             pass
 
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_float16(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_local(self):
-        ...
-
-    # override PipelineTesterMixin
-    @unittest.skip("save pretrained not implemented")
-    def test_save_load_optional_components(self):
-        ...
-
 
 @slow
 @require_torch_gpu
