From c9c5436c94a720ece3ce05b82c86ea21ca86a656 Mon Sep 17 00:00:00 2001 From: Lukas Kuhn Date: Tue, 14 Nov 2023 11:35:26 +0100 Subject: [PATCH] download_from_original_stable_diffusion_ckpt initializes correct default pipeline for SDXL (#5784) * feat: sdxl will be automatically detected as pipeline_class * fix: formatting * fix: formatting with black * fix: import pipeline wrongly sorted --- .../pipelines/stable_diffusion/convert_from_ckpt.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py index 8c1d52ca83..35466f008f 100644 --- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py @@ -1232,13 +1232,11 @@ def download_from_original_stable_diffusion_ckpt( StableDiffusionPipeline, StableDiffusionUpscalePipeline, StableDiffusionXLImg2ImgPipeline, + StableDiffusionXLPipeline, StableUnCLIPImg2ImgPipeline, StableUnCLIPPipeline, ) - if pipeline_class is None: - pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline - if prediction_type == "v-prediction": prediction_type = "v_prediction" @@ -1333,6 +1331,13 @@ def download_from_original_stable_diffusion_ckpt( if image_size is None: image_size = 1024 + if pipeline_class is None: + # Check if we have a SDXL or SD model and initialize default pipeline + if model_type not in ["SDXL", "SDXL-Refiner"]: + pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline + else: + pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline + if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline: num_in_channels = 9 if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline: