|
45 | 45 | PNDMScheduler,
|
46 | 46 | PriorTransformer,
|
47 | 47 | StableDiffusionControlNetPipeline,
|
| 48 | + StableDiffusionImg2ImgPipeline, |
| 49 | + StableDiffusionInpaintPipeline, |
48 | 50 | StableDiffusionPipeline,
|
49 | 51 | StableUnCLIPImg2ImgPipeline,
|
50 | 52 | StableUnCLIPPipeline,
|
@@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt(
|
979 | 981 | image_size: int = 512,
|
980 | 982 | prediction_type: str = None,
|
981 | 983 | model_type: str = None,
|
| 984 | + is_img2img: bool = False, |
982 | 985 | extract_ema: bool = False,
|
983 | 986 | scheduler_type: str = "pndm",
|
984 | 987 | num_in_channels: Optional[int] = None,
|
@@ -1018,6 +1021,8 @@ def download_from_original_stable_diffusion_ckpt(
|
1018 | 1021 | model_type (`str`, *optional*, defaults to `None`):
|
1019 | 1022 | The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
|
1020 | 1023 | "FrozenCLIPEmbedder", "PaintByExample"]`.
|
| 1024 | + is_img2img (`bool`, *optional*, defaults to `False`): |
| 1025 | + Whether the model should be loaded as an img2img pipeline. |
1021 | 1026 | extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
|
1022 | 1027 | checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
|
1023 | 1028 | `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
|
@@ -1193,16 +1198,44 @@ def download_from_original_stable_diffusion_ckpt(
|
1193 | 1198 | requires_safety_checker=False,
|
1194 | 1199 | )
|
1195 | 1200 | else:
|
1196 |
| - pipe = StableDiffusionPipeline( |
1197 |
| - vae=vae, |
1198 |
| - text_encoder=text_model, |
1199 |
| - tokenizer=tokenizer, |
1200 |
| - unet=unet, |
1201 |
| - scheduler=scheduler, |
1202 |
| - safety_checker=None, |
1203 |
| - feature_extractor=None, |
1204 |
| - requires_safety_checker=False, |
1205 |
| - ) |
| 1201 | + if ( |
| 1202 | + hasattr(original_config, "model") |
| 1203 | + and hasattr(original_config.model, "target") |
| 1204 | + and "LatentInpaintDiffusion" in original_config.model.target |
| 1205 | + ): |
| 1206 | + pipe = StableDiffusionInpaintPipeline( |
| 1207 | + vae=vae, |
| 1208 | + text_encoder=text_model, |
| 1209 | + tokenizer=tokenizer, |
| 1210 | + unet=unet, |
| 1211 | + scheduler=scheduler, |
| 1212 | + safety_checker=None, |
| 1213 | + feature_extractor=None, |
| 1214 | + requires_safety_checker=False, |
| 1215 | + ) |
| 1216 | + else: |
| 1217 | + if is_img2img: |
| 1218 | + pipe = StableDiffusionImg2ImgPipeline( |
| 1219 | + vae=vae, |
| 1220 | + text_encoder=text_model, |
| 1221 | + tokenizer=tokenizer, |
| 1222 | + unet=unet, |
| 1223 | + scheduler=scheduler, |
| 1224 | + safety_checker=None, |
| 1225 | + feature_extractor=None, |
| 1226 | + requires_safety_checker=False, |
| 1227 | + ) |
| 1228 | + else: |
| 1229 | + pipe = StableDiffusionPipeline( |
| 1230 | + vae=vae, |
| 1231 | + text_encoder=text_model, |
| 1232 | + tokenizer=tokenizer, |
| 1233 | + unet=unet, |
| 1234 | + scheduler=scheduler, |
| 1235 | + safety_checker=None, |
| 1236 | + feature_extractor=None, |
| 1237 | + requires_safety_checker=False, |
| 1238 | + ) |
1206 | 1239 | else:
|
1207 | 1240 | image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
|
1208 | 1241 | original_config, clip_stats_path=clip_stats_path, device=device
|
@@ -1293,15 +1326,41 @@ def download_from_original_stable_diffusion_ckpt(
|
1293 | 1326 | feature_extractor=feature_extractor,
|
1294 | 1327 | )
|
1295 | 1328 | else:
|
1296 |
| - pipe = StableDiffusionPipeline( |
1297 |
| - vae=vae, |
1298 |
| - text_encoder=text_model, |
1299 |
| - tokenizer=tokenizer, |
1300 |
| - unet=unet, |
1301 |
| - scheduler=scheduler, |
1302 |
| - safety_checker=safety_checker, |
1303 |
| - feature_extractor=feature_extractor, |
1304 |
| - ) |
| 1329 | + if ( |
| 1330 | + hasattr(original_config, "model") |
| 1331 | + and hasattr(original_config.model, "target") |
| 1332 | + and "LatentInpaintDiffusion" in original_config.model.target |
| 1333 | + ): |
| 1334 | + pipe = StableDiffusionInpaintPipeline( |
| 1335 | + vae=vae, |
| 1336 | + text_encoder=text_model, |
| 1337 | + tokenizer=tokenizer, |
| 1338 | + unet=unet, |
| 1339 | + scheduler=scheduler, |
| 1340 | + safety_checker=safety_checker, |
| 1341 | + feature_extractor=feature_extractor, |
| 1342 | + ) |
| 1343 | + else: |
| 1344 | + if is_img2img: |
| 1345 | + pipe = StableDiffusionImg2ImgPipeline( |
| 1346 | + vae=vae, |
| 1347 | + text_encoder=text_model, |
| 1348 | + tokenizer=tokenizer, |
| 1349 | + unet=unet, |
| 1350 | + scheduler=scheduler, |
| 1351 | + safety_checker=safety_checker, |
| 1352 | + feature_extractor=feature_extractor, |
| 1353 | + ) |
| 1354 | + else: |
| 1355 | + pipe = StableDiffusionPipeline( |
| 1356 | + vae=vae, |
| 1357 | + text_encoder=text_model, |
| 1358 | + tokenizer=tokenizer, |
| 1359 | + unet=unet, |
| 1360 | + scheduler=scheduler, |
| 1361 | + safety_checker=safety_checker, |
| 1362 | + feature_extractor=feature_extractor, |
| 1363 | + ) |
1305 | 1364 | else:
|
1306 | 1365 | text_config = create_ldm_bert_config(original_config)
|
1307 | 1366 | text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
|
|
0 commit comments