Skip to content

Commit bdeff4d

Browse files
authored
[ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model (#2705)
* [ckpt loader] Allow loading the Inpaint and Img2Img pipelines, while loading a ckpt model
* Address review comment from PR
* PyLint formatting
* Some more pylint fixes, unrelated to our change
* Another pylint fix
* Styling fix
1 parent fc18839 commit bdeff4d

File tree

1 file changed

+78
-19
lines changed

1 file changed

+78
-19
lines changed

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 78 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@
4545
PNDMScheduler,
4646
PriorTransformer,
4747
StableDiffusionControlNetPipeline,
48+
StableDiffusionImg2ImgPipeline,
49+
StableDiffusionInpaintPipeline,
4850
StableDiffusionPipeline,
4951
StableUnCLIPImg2ImgPipeline,
5052
StableUnCLIPPipeline,
@@ -979,6 +981,7 @@ def download_from_original_stable_diffusion_ckpt(
979981
image_size: int = 512,
980982
prediction_type: str = None,
981983
model_type: str = None,
984+
is_img2img: bool = False,
982985
extract_ema: bool = False,
983986
scheduler_type: str = "pndm",
984987
num_in_channels: Optional[int] = None,
@@ -1018,6 +1021,8 @@ def download_from_original_stable_diffusion_ckpt(
10181021
model_type (`str`, *optional*, defaults to `None`):
10191022
The pipeline type. `None` to automatically infer, or one of `["FrozenOpenCLIPEmbedder",
10201023
"FrozenCLIPEmbedder", "PaintByExample"]`.
1024+
is_img2img (`bool`, *optional*, defaults to `False`):
1025+
Whether the model should be loaded as an img2img pipeline.
10211026
extract_ema (`bool`, *optional*, defaults to `False`): Only relevant for
10221027
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights or not. Defaults to
10231028
`False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher quality images for
@@ -1193,16 +1198,44 @@ def download_from_original_stable_diffusion_ckpt(
11931198
requires_safety_checker=False,
11941199
)
11951200
else:
1196-
pipe = StableDiffusionPipeline(
1197-
vae=vae,
1198-
text_encoder=text_model,
1199-
tokenizer=tokenizer,
1200-
unet=unet,
1201-
scheduler=scheduler,
1202-
safety_checker=None,
1203-
feature_extractor=None,
1204-
requires_safety_checker=False,
1205-
)
1201+
if (
1202+
hasattr(original_config, "model")
1203+
and hasattr(original_config.model, "target")
1204+
and "LatentInpaintDiffusion" in original_config.model.target
1205+
):
1206+
pipe = StableDiffusionInpaintPipeline(
1207+
vae=vae,
1208+
text_encoder=text_model,
1209+
tokenizer=tokenizer,
1210+
unet=unet,
1211+
scheduler=scheduler,
1212+
safety_checker=None,
1213+
feature_extractor=None,
1214+
requires_safety_checker=False,
1215+
)
1216+
else:
1217+
if is_img2img:
1218+
pipe = StableDiffusionImg2ImgPipeline(
1219+
vae=vae,
1220+
text_encoder=text_model,
1221+
tokenizer=tokenizer,
1222+
unet=unet,
1223+
scheduler=scheduler,
1224+
safety_checker=None,
1225+
feature_extractor=None,
1226+
requires_safety_checker=False,
1227+
)
1228+
else:
1229+
pipe = StableDiffusionPipeline(
1230+
vae=vae,
1231+
text_encoder=text_model,
1232+
tokenizer=tokenizer,
1233+
unet=unet,
1234+
scheduler=scheduler,
1235+
safety_checker=None,
1236+
feature_extractor=None,
1237+
requires_safety_checker=False,
1238+
)
12061239
else:
12071240
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
12081241
original_config, clip_stats_path=clip_stats_path, device=device
@@ -1293,15 +1326,41 @@ def download_from_original_stable_diffusion_ckpt(
12931326
feature_extractor=feature_extractor,
12941327
)
12951328
else:
1296-
pipe = StableDiffusionPipeline(
1297-
vae=vae,
1298-
text_encoder=text_model,
1299-
tokenizer=tokenizer,
1300-
unet=unet,
1301-
scheduler=scheduler,
1302-
safety_checker=safety_checker,
1303-
feature_extractor=feature_extractor,
1304-
)
1329+
if (
1330+
hasattr(original_config, "model")
1331+
and hasattr(original_config.model, "target")
1332+
and "LatentInpaintDiffusion" in original_config.model.target
1333+
):
1334+
pipe = StableDiffusionInpaintPipeline(
1335+
vae=vae,
1336+
text_encoder=text_model,
1337+
tokenizer=tokenizer,
1338+
unet=unet,
1339+
scheduler=scheduler,
1340+
safety_checker=safety_checker,
1341+
feature_extractor=feature_extractor,
1342+
)
1343+
else:
1344+
if is_img2img:
1345+
pipe = StableDiffusionImg2ImgPipeline(
1346+
vae=vae,
1347+
text_encoder=text_model,
1348+
tokenizer=tokenizer,
1349+
unet=unet,
1350+
scheduler=scheduler,
1351+
safety_checker=safety_checker,
1352+
feature_extractor=feature_extractor,
1353+
)
1354+
else:
1355+
pipe = StableDiffusionPipeline(
1356+
vae=vae,
1357+
text_encoder=text_model,
1358+
tokenizer=tokenizer,
1359+
unet=unet,
1360+
scheduler=scheduler,
1361+
safety_checker=safety_checker,
1362+
feature_extractor=feature_extractor,
1363+
)
13051364
else:
13061365
text_config = create_ldm_bert_config(original_config)
13071366
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)

0 commit comments

Comments
 (0)