|
1 |
| -from typing import Any, Dict, List, Optional, Tuple, Union |
| 1 | +import os |
| 2 | +from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
2 | 3 |
|
3 | 4 | import torch
|
4 | 5 | from torch import nn
|
5 | 6 |
|
6 | 7 | from ...models.controlnet import ControlNetModel, ControlNetOutput
|
7 | 8 | from ...models.modeling_utils import ModelMixin
|
| 9 | +from ...utils import logging |
| 10 | + |
| 11 | + |
| 12 | +logger = logging.get_logger(__name__) |
8 | 13 |
|
9 | 14 |
|
10 | 15 | class MultiControlNetModel(ModelMixin):
|
@@ -64,3 +69,117 @@ def forward(
|
64 | 69 | mid_block_res_sample += mid_sample
|
65 | 70 |
|
66 | 71 | return down_block_res_samples, mid_block_res_sample
|
| 72 | + |
| 73 | + def save_pretrained( |
| 74 | + self, |
| 75 | + save_directory: Union[str, os.PathLike], |
| 76 | + is_main_process: bool = True, |
| 77 | + save_function: Callable = None, |
| 78 | + safe_serialization: bool = False, |
| 79 | + variant: Optional[str] = None, |
| 80 | + ): |
| 81 | + """ |
| 82 | + Save a model and its configuration file to a directory, so that it can be re-loaded using the |
| 83 | + `[`~pipelines.controlnet.MultiControlNetModel.from_pretrained`]` class method. |
| 84 | +
|
| 85 | + Arguments: |
| 86 | + save_directory (`str` or `os.PathLike`): |
| 87 | + Directory to which to save. Will be created if it doesn't exist. |
| 88 | + is_main_process (`bool`, *optional*, defaults to `True`): |
| 89 | + Whether the process calling this is the main process or not. Useful when in distributed training like |
| 90 | + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on |
| 91 | + the main process to avoid race conditions. |
| 92 | + save_function (`Callable`): |
| 93 | + The function to use to save the state dictionary. Useful on distributed training like TPUs when one |
| 94 | + need to replace `torch.save` by another method. Can be configured with the environment variable |
| 95 | + `DIFFUSERS_SAVE_MODE`. |
| 96 | + safe_serialization (`bool`, *optional*, defaults to `False`): |
| 97 | + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). |
| 98 | + variant (`str`, *optional*): |
| 99 | + If specified, weights are saved in the format pytorch_model.<variant>.bin. |
| 100 | + """ |
| 101 | + idx = 0 |
| 102 | + model_path_to_save = save_directory |
| 103 | + for controlnet in self.nets: |
| 104 | + controlnet.save_pretrained( |
| 105 | + model_path_to_save, |
| 106 | + is_main_process=is_main_process, |
| 107 | + save_function=save_function, |
| 108 | + safe_serialization=safe_serialization, |
| 109 | + variant=variant, |
| 110 | + ) |
| 111 | + |
| 112 | + idx += 1 |
| 113 | + model_path_to_save = model_path_to_save + f"_{idx}" |
| 114 | + |
| 115 | + @classmethod |
| 116 | + def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs): |
| 117 | + r""" |
| 118 | + Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models. |
| 119 | +
|
| 120 | + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train |
| 121 | + the model, you should first set it back in training mode with `model.train()`. |
| 122 | +
|
| 123 | + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come |
| 124 | + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning |
| 125 | + task. |
| 126 | +
|
| 127 | + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those |
| 128 | + weights are discarded. |
| 129 | +
|
| 130 | + Parameters: |
| 131 | + pretrained_model_path (`os.PathLike`): |
| 132 | + A path to a *directory* containing model weights saved using |
| 133 | + [`~diffusers.pipelines.controlnet.MultiControlNetModel.save_pretrained`], e.g., |
| 134 | + `./my_model_directory/controlnet`. |
| 135 | + torch_dtype (`str` or `torch.dtype`, *optional*): |
| 136 | + Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype |
| 137 | + will be automatically derived from the model's weights. |
| 138 | + output_loading_info(`bool`, *optional*, defaults to `False`): |
| 139 | + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. |
| 140 | + device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): |
| 141 | + A map that specifies where each submodule should go. It doesn't need to be refined to each |
| 142 | + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the |
| 143 | + same device. |
| 144 | +
|
| 145 | + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For |
| 146 | + more information about each option see [designing a device |
| 147 | + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). |
| 148 | + max_memory (`Dict`, *optional*): |
| 149 | + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each |
| 150 | + GPU and the available CPU RAM if unset. |
| 151 | + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): |
| 152 | + Speed up model loading by not initializing the weights and only loading the pre-trained weights. This |
| 153 | + also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the |
| 154 | + model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, |
| 155 | + setting this argument to `True` will raise an error. |
| 156 | + variant (`str`, *optional*): |
| 157 | + If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is |
| 158 | + ignored when using `from_flax`. |
| 159 | + use_safetensors (`bool`, *optional*, defaults to `None`): |
| 160 | + If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the |
| 161 | + `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from |
| 162 | + `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. |
| 163 | + """ |
| 164 | + idx = 0 |
| 165 | + controlnets = [] |
| 166 | + |
| 167 | + # load controlnet and append to list until no controlnet directory exists anymore |
| 168 | + # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained` |
| 169 | + # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ... |
| 170 | + model_path_to_load = pretrained_model_path |
| 171 | + while os.path.isdir(model_path_to_load): |
| 172 | + controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs) |
| 173 | + controlnets.append(controlnet) |
| 174 | + |
| 175 | + idx += 1 |
| 176 | + model_path_to_load = pretrained_model_path + f"_{idx}" |
| 177 | + |
| 178 | + logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.") |
| 179 | + |
| 180 | + if len(controlnets) == 0: |
| 181 | + raise ValueError( |
| 182 | + f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}." |
| 183 | + ) |
| 184 | + |
| 185 | + return cls(controlnets) |
0 commit comments