From fbca2e0a7a6643cd59de498ad7de0f63182edd19 Mon Sep 17 00:00:00 2001
From: Eugene Antropov
Date: Wed, 30 Aug 2023 06:30:53 +0300
Subject: [PATCH] Add loading ckpt from file for SDXL controlNet (#4683)

* Add load ckpt from file for ControlNet SDXL

* Reformat code

* Resort imports

---------

Co-authored-by: Patrick von Platen
---
 .../controlnet/pipeline_controlnet_sd_xl.py   |  7 ++--
 .../stable_diffusion/convert_from_ckpt.py     | 33 +++++++++++++------
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index e0f039c7d5..c99f007f9c 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -26,7 +26,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.attention_processor import (
     AttnProcessor2_0,
@@ -102,7 +102,9 @@ EXAMPLE_DOC_STRING = """
 """
 
 
-class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin):
+class StableDiffusionXLControlNetPipeline(
+    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
 
@@ -112,6 +114,7 @@ class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoa
     In addition the pipeline inherits the following loading methods:
         - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
 
     Args:
         vae ([`AutoencoderKL`]):
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 13528ae4ab..22b424bf16 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -1599,16 +1599,29 @@ def download_from_original_stable_diffusion_ckpt(
                 for param_name, param in converted_unet_checkpoint.items():
                     set_module_tensor_to_device(unet, param_name, "cpu", value=param)
 
-            pipe = pipeline_class(
-                vae=vae,
-                text_encoder=text_encoder,
-                tokenizer=tokenizer,
-                text_encoder_2=text_encoder_2,
-                tokenizer_2=tokenizer_2,
-                unet=unet,
-                scheduler=scheduler,
-                force_zeros_for_empty_prompt=True,
-            )
+            if controlnet:
+                pipe = pipeline_class(
+                    vae=vae,
+                    text_encoder=text_encoder,
+                    tokenizer=tokenizer,
+                    text_encoder_2=text_encoder_2,
+                    tokenizer_2=tokenizer_2,
+                    unet=unet,
+                    controlnet=controlnet,
+                    scheduler=scheduler,
+                    force_zeros_for_empty_prompt=True,
+                )
+            else:
+                pipe = pipeline_class(
+                    vae=vae,
+                    text_encoder=text_encoder,
+                    tokenizer=tokenizer,
+                    text_encoder_2=text_encoder_2,
+                    tokenizer_2=tokenizer_2,
+                    unet=unet,
+                    scheduler=scheduler,
+                    force_zeros_for_empty_prompt=True,
+                )
         else:
             tokenizer = None
             text_encoder = None
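
A minimal usage sketch of what this patch enables: constructing StableDiffusionXLControlNetPipeline directly from a single checkpoint file via the newly inherited from_single_file. The local checkpoint path is a hypothetical placeholder and the ControlNet repository id is only illustrative; they are not part of the patch.

# Sketch only: load an SDXL ControlNet pipeline from a single checkpoint file.
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Illustrative ControlNet weights; any SDXL-compatible ControlNet works here.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
)

# New with this change: the pipeline exposes FromSingleFileMixin.from_single_file,
# so a raw SDXL base checkpoint can be loaded without a diffusers-format repo.
pipe = StableDiffusionXLControlNetPipeline.from_single_file(
    "./sd_xl_base_1.0.safetensors",  # hypothetical local checkpoint path
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")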