1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow custom pipeline loading (#3504)

This commit is contained in:
Patrick von Platen
2023-05-23 14:20:55 +02:00
committed by GitHub
parent b134f6a8b6
commit d4197bf4d7
2 changed files with 34 additions and 3 deletions

View File

@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
module_path_items = module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
# retrieve class_name
class_name = module.__class__.__name__
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
# 6.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
# 6.3 Use passed sub model or load class_name from library_name

View File

@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
from diffusers import (
AutoencoderKL,
ConfigMixin,
DDIMPipeline,
DDIMScheduler,
DDPMPipeline,
@@ -44,6 +45,7 @@ from diffusers import (
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
ModelMixin,
PNDMScheduler,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipelineLegacy,
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
enable_full_determinism()
class CustomEncoder(ModelMixin, ConfigMixin):
def __init__(self):
super().__init__()
class CustomPipeline(DiffusionPipeline):
def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler):
super().__init__()
self.register_modules(encoder=encoder, scheduler=scheduler)
class DownloadTests(unittest.TestCase):
def test_one_request_upon_cached(self):
# TODO: For some reason this test fails on MPS where no HEAD call is made.
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
assert output_str == "This is a local test"
def test_custom_model_and_pipeline(self):
pipe = CustomPipeline(
encoder=CustomEncoder(),
scheduler=DDIMScheduler(),
)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
pipe_new.save_pretrained(tmpdirname)
assert dict(pipe_new.config) == dict(pipe.config)
@slow
@require_torch_gpu
def test_download_from_git(self):