mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
allow passing components to connected pipelines when use the combined pipeline (#4883)
* fix * add test --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
@@ -1147,8 +1147,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"variant": variant,
|
||||
"use_safetensors": use_safetensors,
|
||||
}
|
||||
|
||||
def get_connected_passed_kwargs(prefix):
|
||||
connected_passed_class_obj = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
connected_passed_pipe_kwargs = {
|
||||
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
||||
}
|
||||
|
||||
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
||||
return connected_passed_kwargs
|
||||
|
||||
connected_pipes = {
|
||||
prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy())
|
||||
prefix: DiffusionPipeline.from_pretrained(
|
||||
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
|
||||
)
|
||||
for prefix, repo_id in connected_pipes.items()
|
||||
if repo_id is not None
|
||||
}
|
||||
|
||||
@@ -18,7 +18,13 @@ import unittest
|
||||
import torch
|
||||
from huggingface_hub import ModelCard
|
||||
|
||||
from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline
|
||||
from diffusers import (
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
KandinskyV22CombinedPipeline,
|
||||
KandinskyV22Pipeline,
|
||||
KandinskyV22PriorPipeline,
|
||||
)
|
||||
from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS
|
||||
|
||||
|
||||
@@ -101,3 +107,22 @@ class CombinedPipelineFastTest(unittest.TestCase):
|
||||
assert dict(component.config) == dict(comp.config)
|
||||
else:
|
||||
assert component.__class__ == comp.__class__
|
||||
|
||||
def test_load_connected_checkpoint_with_passed_obj(self):
|
||||
pipeline = KandinskyV22CombinedPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-kandinsky-v22-decoder"
|
||||
)
|
||||
prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config)
|
||||
scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
|
||||
|
||||
# make sure we pass a different scheduler and prior_scheduler
|
||||
assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__
|
||||
assert pipeline.scheduler.__class__ != scheduler.__class__
|
||||
|
||||
pipeline_new = KandinskyV22CombinedPipeline.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-kandinsky-v22-decoder",
|
||||
prior_scheduler=prior_scheduler,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config)
|
||||
assert dict(pipeline_new.scheduler.config) == dict(scheduler.config)
|
||||
|
||||
Reference in New Issue
Block a user