From 2de9e2df368241cf13f859cf51514cea4e53aed5 Mon Sep 17 00:00:00 2001 From: "Jason C.H" Date: Wed, 7 Jun 2023 05:39:11 +0800 Subject: [PATCH] Fix from_ckpt for Stable Diffusion 2.x (#3662) --- src/diffusers/loaders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6ecc701f83..4b7bb69535 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1442,23 +1442,25 @@ class FromCkptMixin: # TODO: For now we only support stable diffusion stable_unclip = None + model_type = None controlnet = False if pipeline_name == "StableDiffusionControlNetPipeline": - model_type = "FrozenCLIPEmbedder" + # Model type will be inferred from the checkpoint. controlnet = True elif "StableDiffusion" in pipeline_name: - model_type = "FrozenCLIPEmbedder" + # Model type will be inferred from the checkpoint. + pass elif pipeline_name == "StableUnCLIPPipeline": - model_type == "FrozenOpenCLIPEmbedder" + model_type = "FrozenOpenCLIPEmbedder" stable_unclip = "txt2img" elif pipeline_name == "StableUnCLIPImg2ImgPipeline": - model_type == "FrozenOpenCLIPEmbedder" + model_type = "FrozenOpenCLIPEmbedder" stable_unclip = "img2img" elif pipeline_name == "PaintByExamplePipeline": - model_type == "PaintByExample" + model_type = "PaintByExample" elif pipeline_name == "LDMTextToImagePipeline": - model_type == "LDMTextToImage" + model_type = "LDMTextToImage" else: raise ValueError(f"Unhandled pipeline class: {pipeline_name}")