
Get refiner to work

Patrick von Platen
2023-06-25 23:36:28 +00:00
parent ea4cf25928
commit 48d203eeea
2 changed files with 25 additions and 12 deletions

View File

@@ -270,16 +270,21 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim = [5 * c for c in list(unet_params.channel_mult)]
+            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]
 
     class_embed_type = None
     addition_embed_type = None
     addition_time_embed_dim = None
     projection_class_embeddings_input_dim = None
+    context_dim = None
+
+    if unet_params.context_dim is not None:
+        context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
 
     if "num_classes" in unet_params:
         if unet_params.num_classes == "sequential":
-            if unet_params.context_dim == 2048:
+            if context_dim in [2048, 1280]:
                 # SDXL
                 addition_embed_type = "text_time"
                 addition_time_embed_dim = 256
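
Note: the multiplier 5 in the removed line only holds for SD 2.x checkpoints (320 model channels / 64 head channels), so it is now derived from the config, which is what lets the refiner UNet (384 model channels) convert correctly. A standalone sanity check of the arithmetic, with config values assumed from the published SD 2.x and SDXL-Refiner YAMLs rather than taken from this diff:

    # Hypothetical UNet config values, assumed from the reference YAMLs:
    sd2 = {"model_channels": 320, "num_head_channels": 64, "channel_mult": [1, 2, 4, 4]}
    refiner = {"model_channels": 384, "num_head_channels": 64, "channel_mult": [1, 2, 4, 4]}

    for name, p in [("sd2", sd2), ("refiner", refiner)]:
        head_dim_mult = p["model_channels"] // p["num_head_channels"]
        print(name, [head_dim_mult * c for c in p["channel_mult"]])
    # sd2     -> [5, 10, 20, 20]   (matches the old hard-coded 5 * c)
    # refiner -> [6, 12, 24, 24]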
@@ -296,7 +301,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
         "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "cross_attention_dim": context_dim,
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
@@ -1272,6 +1277,8 @@ def download_from_original_stable_diffusion_ckpt(
     elif model_type is None and original_config.model.params.network_config is not None:
         if original_config.model.params.network_config.params.context_dim == 2048:
             model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"
 
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
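
Note: the new `else` branch gives the refiner its own model type. The discriminator is the cross-attention width: SDXL-base conditions on both text encoders concatenated (768 from CLIP ViT-L plus 1280 from OpenCLIP ViT-bigG gives 2048), while the refiner conditions on ViT-bigG alone (1280). A standalone sketch of the rule, with `SimpleNamespace` standing in for the parsed OmegaConf config:

    from types import SimpleNamespace as NS

    def detect(context_dim):
        # Mimics original_config.model.params.network_config.params.context_dim
        config = NS(model=NS(params=NS(network_config=NS(params=NS(context_dim=context_dim)))))
        dim = config.model.params.network_config.params.context_dim
        return "SDXL" if dim == 2048 else "SDXL-Refiner"

    assert detect(2048) == "SDXL"
    assert detect(1280) == "SDXL-Refiner"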
@@ -1400,12 +1407,18 @@ def download_from_original_stable_diffusion_ckpt(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-    elif model_type == "SDXL":
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-        tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280)
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+    elif model_type in ["SDXL", "SDXL-Refiner"]:
+        if model_type == "SDXL":
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+        else:
+            tokenizer = None
+            text_encoder = None
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
         pipe = StableDiffusionXLPipeline(
             vae=vae,
             text_encoder=text_encoder,

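Note: a refiner checkpoint keeps only the second (OpenCLIP ViT-bigG) tokenizer/encoder pair, and its weights are read from `conditioner.embedders.0.model.` because ViT-bigG is the refiner's only embedder; in the base model it is the second embedder, hence `conditioner.embedders.1.model.`. A hypothetical invocation of the updated converter; the import path assumes the function's mid-2023 location in the library, and the checkpoint/config filenames are placeholders:

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Placeholder paths; any local refiner checkpoint plus its matching YAML would do.
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path="sd_xl_refiner_0.9.safetensors",
        original_config_file="sd_xl_refiner.yaml",
        from_safetensors=True,
    )
    # model_type is inferred from context_dim, so the refiner branch runs
    # automatically and pipe.tokenizer / pipe.text_encoder come back as None.
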
View File

@@ -104,7 +104,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"]
 
     def __init__(
         self,
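
Note: registering `tokenizer` and `text_encoder` as optional components is what allows the refiner pipeline to carry `None` in those slots and to round-trip through save/load without them. A toy illustration of the semantics the list encodes (this is not the diffusers implementation):

    _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"]

    def register_modules(**modules):
        for name, module in modules.items():
            if module is None and name not in _optional_components:
                raise ValueError(f"{name} is required but got None")
            print(f"registered {name} -> {type(module).__name__}")

    register_modules(tokenizer=None, text_encoder=None)  # fine: both are optional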
@@ -304,8 +304,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2]
-        text_encoders = [self.text_encoder.to(device), self.text_encoder_2.to(device)]
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
 
         if prompt_embeds is None:
             # textual inversion: process multi-vector tokens if necessary
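
Note: the prompt-encoding loop now builds its tokenizer/encoder lists conditionally so one code path serves both variants: the base model iterates over both pairs, the refiner over `tokenizer_2`/`text_encoder_2` alone. The eager `.to(device)` moves are also dropped from the list construction. A tiny standalone check of the fallback, with strings standing in for the real modules:

    class RefinerLike:
        tokenizer, text_encoder = None, None        # refiner: first pair absent
        tokenizer_2, text_encoder_2 = "tok2", "enc2"

    p = RefinerLike()
    tokenizers = [p.tokenizer, p.tokenizer_2] if p.tokenizer is not None else [p.tokenizer_2]
    text_encoders = [p.text_encoder, p.text_encoder_2] if p.text_encoder is not None else [p.text_encoder_2]
    assert tokenizers == ["tok2"] and text_encoders == ["enc2"]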