1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add loading ckpt from file for SDXL controlNet (#4683)

* Add load ckpt from file for ControlNet SDXL

* Reformat code

* Resort imports

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Eugene Antropov
2023-08-30 06:30:53 +03:00
committed by GitHub
parent 3768d4d77c
commit fbca2e0a7a
2 changed files with 28 additions and 12 deletions

View File

@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
from diffusers.utils.import_utils import is_invisible_watermark_available
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.attention_processor import (
AttnProcessor2_0,
@@ -102,7 +102,9 @@ EXAMPLE_DOC_STRING = """
"""
class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
class StableDiffusionXLControlNetPipeline(
DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
@@ -112,6 +114,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
In addition the pipeline inherits the following loading methods:
- *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
- *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
- *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
Args:
vae ([`AutoencoderKL`]):

View File

@@ -1599,16 +1599,29 @@ def download_from_original_stable_diffusion_ckpt(
for param_name, param in converted_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
if controlnet:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
controlnet=controlnet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
pipe = pipeline_class(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_encoder_2=text_encoder_2,
tokenizer_2=tokenizer_2,
unet=unet,
scheduler=scheduler,
force_zeros_for_empty_prompt=True,
)
else:
tokenizer = None
text_encoder = None