Skip to content

Commit d4197bf

Browse files
Allow custom pipeline loading (#3504)
1 parent b134f6a commit d4197bf

File tree

2 files changed

+34
-3
lines changed

2 files changed

+34
-3
lines changed

src/diffusers/pipelines/pipeline_utils.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -491,15 +491,19 @@ def register_modules(self, **kwargs):
491491
library = module.__module__.split(".")[0]
492492

493493
# check if the module is a pipeline module
494-
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
494+
module_path_items = module.__module__.split(".")
495+
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
496+
495497
path = module.__module__.split(".")
496498
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
497499

498500
# if library is not in LOADABLE_CLASSES, then it is a custom module.
499501
# Or if it's a pipeline module, then the module is inside the pipeline
500502
# folder so we set the library to module name.
501-
if library not in LOADABLE_CLASSES or is_pipeline_module:
503+
if is_pipeline_module:
502504
library = pipeline_dir
505+
elif library not in LOADABLE_CLASSES:
506+
library = module.__module__
503507

504508
# retrieve class_name
505509
class_name = module.__class__.__name__
@@ -1039,7 +1043,7 @@ def load_module(name, value):
10391043

10401044
# 6.2 Define all importable classes
10411045
is_pipeline_module = hasattr(pipelines, library_name)
1042-
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
1046+
importable_classes = ALL_IMPORTABLE_CLASSES
10431047
loaded_sub_model = None
10441048

10451049
# 6.3 Use passed sub model or load class_name from library_name

tests/pipelines/test_pipelines.py

+27
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
from diffusers import (
3737
AutoencoderKL,
38+
ConfigMixin,
3839
DDIMPipeline,
3940
DDIMScheduler,
4041
DDPMPipeline,
@@ -44,6 +45,7 @@
4445
EulerAncestralDiscreteScheduler,
4546
EulerDiscreteScheduler,
4647
LMSDiscreteScheduler,
48+
ModelMixin,
4749
PNDMScheduler,
4850
StableDiffusionImg2ImgPipeline,
4951
StableDiffusionInpaintPipelineLegacy,
@@ -77,6 +79,17 @@
7779
enable_full_determinism()
7880

7981

82+
class CustomEncoder(ModelMixin, ConfigMixin):
83+
def __init__(self):
84+
super().__init__()
85+
86+
87+
class CustomPipeline(DiffusionPipeline):
88+
def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler):
89+
super().__init__()
90+
self.register_modules(encoder=encoder, scheduler=scheduler)
91+
92+
8093
class DownloadTests(unittest.TestCase):
8194
def test_one_request_upon_cached(self):
8295
# TODO: For some reason this test fails on MPS where no HEAD call is made.
@@ -695,6 +708,20 @@ def test_local_custom_pipeline_file(self):
695708
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
696709
assert output_str == "This is a local test"
697710

711+
def test_custom_model_and_pipeline(self):
712+
pipe = CustomPipeline(
713+
encoder=CustomEncoder(),
714+
scheduler=DDIMScheduler(),
715+
)
716+
717+
with tempfile.TemporaryDirectory() as tmpdirname:
718+
pipe.save_pretrained(tmpdirname)
719+
720+
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
721+
pipe_new.save_pretrained(tmpdirname)
722+
723+
assert dict(pipe_new.config) == dict(pipe.config)
724+
698725
@slow
699726
@require_torch_gpu
700727
def test_download_from_git(self):

0 commit comments

Comments
 (0)