mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add loading ckpt from file for SDXL controlNet (#4683)
* Add load ckpt from file for ControlNet SDXL * Reformat code * Resort imports --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
|
||||
from diffusers.utils.import_utils import is_invisible_watermark_available
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor2_0,
|
||||
@@ -102,7 +102,9 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
|
||||
class StableDiffusionXLControlNetPipeline(
|
||||
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
|
||||
|
||||
@@ -112,6 +114,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
|
||||
In addition the pipeline inherits the following loading methods:
|
||||
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
|
||||
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
|
||||
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
|
||||
@@ -1599,16 +1599,29 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
for param_name, param in converted_unet_checkpoint.items():
|
||||
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
|
||||
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
if controlnet:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
controlnet=controlnet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
pipe = pipeline_class(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_encoder_2=text_encoder_2,
|
||||
tokenizer_2=tokenizer_2,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
force_zeros_for_empty_prompt=True,
|
||||
)
|
||||
else:
|
||||
tokenizer = None
|
||||
text_encoder = None
|
||||
|
||||
Reference in New Issue
Block a user