mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Allow custom pipeline loading (#3504)
This commit is contained in:
committed by
GitHub
parent
b134f6a8b6
commit
d4197bf4d7
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
|
||||
library = module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
|
||||
module_path_items = module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = module.__class__.__name__
|
||||
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 6.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# 6.3 Use passed sub model or load class_name from library_name
|
||||
|
||||
@@ -35,6 +35,7 @@ from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPText
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
ConfigMixin,
|
||||
DDIMPipeline,
|
||||
DDIMScheduler,
|
||||
DDPMPipeline,
|
||||
@@ -44,6 +45,7 @@ from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
ModelMixin,
|
||||
PNDMScheduler,
|
||||
StableDiffusionImg2ImgPipeline,
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
@@ -77,6 +79,17 @@ from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CustomEncoder(ModelMixin, ConfigMixin):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
||||
class CustomPipeline(DiffusionPipeline):
|
||||
def __init__(self, encoder: CustomEncoder, scheduler: DDIMScheduler):
|
||||
super().__init__()
|
||||
self.register_modules(encoder=encoder, scheduler=scheduler)
|
||||
|
||||
|
||||
class DownloadTests(unittest.TestCase):
|
||||
def test_one_request_upon_cached(self):
|
||||
# TODO: For some reason this test fails on MPS where no HEAD call is made.
|
||||
@@ -695,6 +708,20 @@ class CustomPipelineTests(unittest.TestCase):
|
||||
# compare to https://github.com/huggingface/diffusers/blob/main/tests/fixtures/custom_pipeline/pipeline.py#L102
|
||||
assert output_str == "This is a local test"
|
||||
|
||||
def test_custom_model_and_pipeline(self):
|
||||
pipe = CustomPipeline(
|
||||
encoder=CustomEncoder(),
|
||||
scheduler=DDIMScheduler(),
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
|
||||
pipe_new = CustomPipeline.from_pretrained(tmpdirname)
|
||||
pipe_new.save_pretrained(tmpdirname)
|
||||
|
||||
assert dict(pipe_new.config) == dict(pipe.config)
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
def test_download_from_git(self):
|
||||
|
||||
Reference in New Issue
Block a user