mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
download_from_original_stable_diffusion_ckpt initializes correct default pipeline for SDXL (#5784)
* feat: sdxl will be automatically detected as pipeline_class * fix: formatting * fix: formatting with black * fix: import pipeline wrongly sorted
This commit is contained in:
@@ -1232,13 +1232,11 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
StableDiffusionPipeline,
|
||||
StableDiffusionUpscalePipeline,
|
||||
StableDiffusionXLImg2ImgPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
|
||||
if pipeline_class is None:
|
||||
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
|
||||
|
||||
if prediction_type == "v-prediction":
|
||||
prediction_type = "v_prediction"
|
||||
|
||||
@@ -1333,6 +1331,13 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if image_size is None:
|
||||
image_size = 1024
|
||||
|
||||
if pipeline_class is None:
|
||||
# Check if we have a SDXL or SD model and initialize default pipeline
|
||||
if model_type not in ["SDXL", "SDXL-Refiner"]:
|
||||
pipeline_class = StableDiffusionPipeline if not controlnet else StableDiffusionControlNetPipeline
|
||||
else:
|
||||
pipeline_class = StableDiffusionXLPipeline if model_type == "SDXL" else StableDiffusionXLImg2ImgPipeline
|
||||
|
||||
if num_in_channels is None and pipeline_class == StableDiffusionInpaintPipeline:
|
||||
num_in_channels = 9
|
||||
if num_in_channels is None and pipeline_class == StableDiffusionUpscalePipeline:
|
||||
|
||||
Reference in New Issue
Block a user