Mirror of https://github.com/huggingface/diffusers.git
Get refiner to work
@@ -270,16 +270,21 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim = [5 * c for c in list(unet_params.channel_mult)]
+            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+            head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

     class_embed_type = None
     addition_embed_type = None
     addition_time_embed_dim = None
     projection_class_embeddings_input_dim = None
+    context_dim = None
+
+    if unet_params.context_dim is not None:
+        context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]

     if "num_classes" in unet_params:
         if unet_params.num_classes == "sequential":
-            if unet_params.context_dim == 2048:
+            if context_dim in [2048, 1280]:
                 # SDXL
                 addition_embed_type = "text_time"
                 addition_time_embed_dim = 256
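
Context, not part of the diff: a rough sketch of what the new head_dim / context_dim handling evaluates to for the two SD-XL UNets. The parameter values are assumed from the official sd_xl_base.yaml and sd_xl_refiner.yaml and are illustrative only; the old hard-coded 5 * c only matched the base-style 320 // 64 layout, which is why the multiplier is now computed.

# Illustrative only: UNet parameters assumed from the SD-XL base / refiner YAMLs, not read from disk.
unet_params_by_model = {
    "base": {"model_channels": 320, "num_head_channels": 64, "channel_mult": [1, 2, 4], "context_dim": 2048},
    "refiner": {"model_channels": 384, "num_head_channels": 64, "channel_mult": [1, 2, 4, 4], "context_dim": [1280, 1280, 1280, 1280]},
}

for name, p in unet_params_by_model.items():
    # a per-block context_dim list collapses to its first entry, mirroring the added normalization
    context_dim = p["context_dim"] if isinstance(p["context_dim"], int) else p["context_dim"][0]
    # the computed multiplier replaces the old hard-coded 5, which only held for 320 // 64
    head_dim_mult = p["model_channels"] // p["num_head_channels"]
    head_dim = [head_dim_mult * c for c in p["channel_mult"]]
    print(name, context_dim, head_dim)

# prints:
#   base 2048 [5, 10, 20]
#   refiner 1280 [6, 12, 24, 24]

Both 2048 and 1280 now enable the SDXL "text_time" additional embeddings, which is why the == 2048 check widens to in [2048, 1280].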
@@ -296,7 +301,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         "down_block_types": tuple(down_block_types),
         "block_out_channels": tuple(block_out_channels),
         "layers_per_block": unet_params.num_res_blocks,
-        "cross_attention_dim": unet_params.context_dim,
+        "cross_attention_dim": context_dim,
         "attention_head_dim": head_dim,
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
@@ -1272,6 +1277,8 @@ def download_from_original_stable_diffusion_ckpt(
     elif model_type is None and original_config.model.params.network_config is not None:
         if original_config.model.params.network_config.params.context_dim == 2048:
             model_type = "SDXL"
+        else:
+            model_type = "SDXL-Refiner"

     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
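
For reference (not from the diff): the branch above keys off the context_dim declared in the original SGM YAML. A minimal sketch with omegaconf, using assumed stand-in configs; the real refiner YAML may express context_dim as a per-block list, but either way it is not 2048.

from omegaconf import OmegaConf

# Illustrative only: minimal stand-ins for sd_xl_base.yaml / sd_xl_refiner.yaml.
for label, context_dim in [("base-like", 2048), ("refiner-like", 1280)]:
    original_config = OmegaConf.create(
        {"model": {"params": {"network_config": {"params": {"context_dim": context_dim}}}}}
    )
    if original_config.model.params.network_config.params.context_dim == 2048:
        model_type = "SDXL"
    else:
        model_type = "SDXL-Refiner"
    print(label, model_type)  # base-like SDXL, then refiner-like SDXL-Refiner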
@@ -1400,12 +1407,18 @@ def download_from_original_stable_diffusion_ckpt(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-    elif model_type == "SDXL":
-        tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
-        text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
-        tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
-        text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280)
-        text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+    elif model_type in ["SDXL", "SDXL-Refiner"]:
+        if model_type == "SDXL":
+            tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+            text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only)
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.1.model.")
+        else:
+            tokenizer = None
+            text_encoder = None
+            tokenizer_2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", pad_token="!")
+            text_encoder_2 = convert_open_clip_checkpoint(checkpoint, prefix="conditioner.embedders.0.model.")
+
         pipe = StableDiffusionXLPipeline(
             vae=vae,
             text_encoder=text_encoder,
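
Background for the encoder wiring above (an assumption about the checkpoint layout, not something the diff states): SGM base checkpoints carry two text encoders, OpenAI CLIP ViT-L under conditioner.embedders.0.transformer. and OpenCLIP ViT-bigG under conditioner.embedders.1.model., while the refiner carries only ViT-bigG, shifted down to conditioner.embedders.0.model. That is why the refiner branch switches the prefix and leaves tokenizer / text_encoder as None. A rough sketch of telling the two apart from the raw state dict:

# Illustrative only: key prefixes assumed from the SGM checkpoint layout described above.
def guess_sdxl_variant(state_dict):
    keys = list(state_dict)
    has_vit_l = any(k.startswith("conditioner.embedders.0.transformer.") for k in keys)
    has_second_big_g = any(k.startswith("conditioner.embedders.1.model.") for k in keys)
    if has_vit_l and has_second_big_g:
        return "SDXL"  # base: two text encoders
    return "SDXL-Refiner"  # refiner: single OpenCLIP ViT-bigG encoder at embedders.0

# Hypothetical usage with a local safetensors file:
# from safetensors.torch import load_file
# print(guess_sdxl_variant(load_file("sd_xl_refiner.safetensors")))  # -> "SDXL-Refiner"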
@@ -104,7 +104,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         feature_extractor ([`CLIPImageProcessor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-    _optional_components = ["safety_checker", "feature_extractor"]
+    _optional_components = ["safety_checker", "feature_extractor", "tokenizer", "text_encoder"]

     def __init__(
         self,
@@ -304,8 +304,8 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
             batch_size = prompt_embeds.shape[0]

         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2]
-        text_encoders = [self.text_encoder.to(device), self.text_encoder_2.to(device)]
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]

         if prompt_embeds is None:
             # textual inversion: procecss multi-vector tokens if necessary
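
One more piece of context (not in the diff): the lists become conditional because the refiner pipeline has no first tokenizer / text encoder at all. SDXL concatenates the per-encoder hidden states along the feature dimension, so the base UNet sees 768 + 1280 = 2048-dimensional context while the refiner sees 1280, which is exactly the [2048, 1280] pair checked in the conversion code. A shape-only sketch, with the hidden sizes (768 for CLIP ViT-L, 1280 for OpenCLIP ViT-bigG) taken as assumptions:

import torch

# Illustrative shapes only.
clip_l_hidden = torch.randn(1, 77, 768)   # first encoder, present for the base model only
big_g_hidden = torch.randn(1, 77, 1280)   # second encoder, present for base and refiner

base_prompt_embeds = torch.cat([clip_l_hidden, big_g_hidden], dim=-1)
refiner_prompt_embeds = big_g_hidden

print(base_prompt_embeds.shape)     # torch.Size([1, 77, 2048]) -> base cross-attention width
print(refiner_prompt_embeds.shape)  # torch.Size([1, 77, 1280]) -> refiner cross-attention width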