From 48d203eeea8568365b224581d0b197332cbcd4d1 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Sun, 25 Jun 2023 23:36:28 +0000
Subject: [PATCH] Get refiner to work

---
 .../stable_diffusion/convert_from_ckpt.py     | 31 +++++++++++++------
 .../pipeline_stable_diffusion_xl.py           |  6 ++--
 2 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 147e1af1d5..e1a83b93db 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -270,16 +270,21 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim = [5 * c for c in list(unet_params.channel_mult)]
+            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
 
     class_embed_type = None
     addition_embed_type = None
     addition_time_embed_dim = None
     projection_class_embeddings_input_dim = None
+    context_dim = None
+
+    if unet_params.context_dim is not None:
+        context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
 
     if "num_classes" in unet_params:
         if unet_params.num_classes == "sequential":
-            if unet_params.context_dim == 2048:
+            if context_dim in [2048, 1280]:
                 # SDXL
                 addition_embed_type = "text_time"
                 addition_time_embed_dim = 256
@@ -296,7 +301,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
         "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "cross_attention_dim": context_dim,
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
@@ -1272,6 +1277,8 @@ def download_from_original_stable_diffusion_ckpt(
     elif model_type is None and original_config.model.params.network_config is not None:
         if original_config.model.params.network_config.params.context_dim == 2048:
             model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"
 
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
@@ -1400,12 +1407,18 @@ def download_from_original_stable_diffusion_ckpt(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-    elif model_type == "SDXL":
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-        tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280)
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+    elif model_type in ["SDXL", "SDXL-Refiner"]:
+        if model_type == "SDXL":
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+        else:
+            tokenizer = None
+            text_encoder = None
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
         pipe = StableDiffusionXLPipeline(
             vae=vae,
             text_encoder=text_encoder,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 0610936b0a..c99f054ef3 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -104,7 +104,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"]
 
     def __init__(
         self,
@@ -304,8 +304,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2]
-        text_encoders = [self.text_encoder.to(device), self.text_encoder_2.to(device)]
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
 
         if prompt_embeds is None:
             # textual inversion: procecss multi-vector tokens if necessary
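
Usage note (not part of the patch): a minimal sketch of how the new
"SDXL-Refiner" conversion path might be exercised, assuming a local
refiner checkpoint and original-config YAML; both file names below are
hypothetical. `download_from_original_stable_diffusion_ckpt`, the
`model_type` argument, and the None tokenizer/text_encoder behavior for
the refiner all come from the diff above.

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Hypothetical file names. model_type may also be left unset, in which
    # case the patch infers it from the original config's network_config
    # (context_dim == 2048 -> "SDXL", otherwise "SDXL-Refiner").
    pipe = download_from_original_stable_diffusion_ckpt(
        "sd_xl_refiner.safetensors",
        original_config_file="sd_xl_refiner.yaml",
        model_type="SDXL-Refiner",
        from_safetensors=True,
    )

    # For the refiner, only the second (OpenCLIP ViT-bigG) tokenizer/text
    # encoder pair is populated, which is why "tokenizer" and "text_encoder"
    # were added to the pipeline's _optional_components.
    assert pipe.tokenizer is None and pipe.text_encoder is None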