|
35 | 35 |
|
36 | 36 | from diffusers import (
|
37 | 37 | AutoencoderKL,
|
| 38 | + ConfigMixin, |
38 | 39 | DDIMPipeline,
|
39 | 40 | DDIMScheduler,
|
40 | 41 | DDPMPipeline,
|
|
44 | 45 | EulerAncestralDiscreteScheduler,
|
45 | 46 | EulerDiscreteScheduler,
|
46 | 47 | LMSDiscreteScheduler,
|
| 48 | + ModelMixin, |
47 | 49 | PNDMScheduler,
|
48 | 50 | StableDiffusionImg2ImgPipeline,
|
49 | 51 | StableDiffusionInpaintPipelineLegacy,
|
|
77 | 79 | enable_full_determinism()
|
78 | 80 |
|
79 | 81 |
|
| 82 | +class CustomEncoder(ModelMixin, ConfigMixin): |
| 83 | + def __init__(self): |
| 84 | + super().__init__() |
| 85 | + |
| 86 | + |
| 87 | +class CustomPipeline(DiffusionPipeline): |
| 88 | + def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler): |
| 89 | + super().__init__() |
| 90 | + self.register_modules(encoder=encoder, scheduler=scheduler) |
| 91 | + |
| 92 | + |
80 | 93 | class DownloadTests(unittest.TestCase):
|
81 | 94 | def test_one_request_upon_cached(self):
|
82 | 95 | # TODO: For some reason this test fails on MPS where no HEAD call is made.
|
@@ -695,6 +708,20 @@ def test_local_custom_pipeline_file(self):
|
695 | 708 | # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
|
696 | 709 | assert output_str == "This is a local test"
|
697 | 710 |
|
| 711 | + def test_custom_model_and_pipeline(self): |
| 712 | + pipe = CustomPipeline( |
| 713 | + encoder=CustomEncoder(), |
| 714 | + scheduler=DDIMScheduler(), |
| 715 | + ) |
| 716 | + |
| 717 | + with tempfile.TemporaryDirectory() as tmpdirname: |
| 718 | + pipe.save_pretrained(tmpdirname) |
| 719 | + |
| 720 | + pipe_new = CustomPipeline.from_pretrained(tmpdirname) |
| 721 | + pipe_new.save_pretrained(tmpdirname) |
| 722 | + |
| 723 | + assert dict(pipe_new.config) == dict(pipe.config) |
| 724 | + |
698 | 725 | @slow
|
699 | 726 | @require_torch_gpu
|
700 | 727 | def test_download_from_git(self):
|
|
0 commit comments