From d4197bf4d72f04d4927ff1e7be2f8ee46efebe47 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Tue, 23 May 2023 14:20:55 +0200
Subject: [PATCH] Allow custom pipeline loading (#3504)

---
 src/diffusers/pipelines/pipeline_utils.py | 10 ++++++---
 tests/pipelines/test_pipelines.py         | 27 +++++++++++++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index aed1139a2a..2f56f650ea 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
                 library = module.__module__.split(".")[0]
 
                 # check if the module is a pipeline module
-                pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
+                module_path_items = module.__module__.split(".")
+                pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
+
                 path = module.__module__.split(".")
                 is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
 
                 # if library is not in LOADABLE_CLASSES, then it is a custom module.
                 # Or if it's a pipeline module, then the module is inside the pipeline
                 # folder so we set the library to module name.
-                if library not in LOADABLE_CLASSES or is_pipeline_module:
+                if is_pipeline_module:
                     library = pipeline_dir
+                elif library not in LOADABLE_CLASSES:
+                    library = module.__module__
 
                 # retrieve class_name
                 class_name = module.__class__.__name__
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
 
             # 6.2 Define all importable classes
             is_pipeline_module = hasattr(pipelines, library_name)
-            importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
+            importable_classes = ALL_IMPORTABLE_CLASSES
             loaded_sub_model = None
 
             # 6.3 Use passed sub model or load class_name from library_name
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index a9abb0b4fb..6ec9ff0346 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
 
 from diffusers import (
     AutoencoderKL,
+    ConfigMixin,
     DDIMPipeline,
     DDIMScheduler,
     DDPMPipeline,
@@ -44,6 +45,7 @@ from diffusers import (
     EulerAncestralDiscreteScheduler,
     EulerDiscreteScheduler,
     LMSDiscreteScheduler,
+    ModelMixin,
     PNDMScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipelineLegacy,
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
 enable_full_determinism()
 
 
+class CustomEncoder(ModelMixin, ConfigMixin):
+    def __init__(self):
+        super().__init__()
+
+
+class CustomPipeline(DiffusionPipeline):
+    def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler):
+        super().__init__()
+        self.register_modules(encoder=encoder, scheduler=scheduler)
+
+
 class DownloadTests(unittest.TestCase):
     def test_one_request_upon_cached(self):
         # TODO: For some reason this test fails on MPS where no HEAD call is made.
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
         # compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
         assert output_str == "This is a local test"
 
+    def test_custom_model_and_pipeline(self):
+        pipe = CustomPipeline(
+            encoder=CustomEncoder(),
+            scheduler=DDIMScheduler(),
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            pipe.save_pretrained(tmpdirname)
+
+            pipe_new = CustomPipeline.from_pretrained(tmpdirname)
+            pipe_new.save_pretrained(tmpdirname)
+
+        assert dict(pipe_new.config) == dict(pipe.config)
+
     @slow
     @require_torch_gpu
     def test_download_from_git(self):