1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Dhruv Nair
2023-12-26 05:09:03 +00:00
parent 8b7eecd4d4
commit 0cd1be42d3
2 changed files with 112 additions and 6 deletions

View File

@@ -55,6 +55,15 @@ MODEL_TYPE_FROM_PIPELINE_CLASS = {
"StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder",
"StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder",
}
PIPELINE_COMPONENTS = {
"unet": ,
"vae": "AutoencoderKL",
"text_encoder": "CLIPTextModel",
"text_encoder_2": "CLIPTextModel",
"tokenizer": "CLIPTokenizer",
"tokenizer_2": "CLIPTokenizer",
"scheduler": "DiffusionScheduler",
}
def extract_pipeline_component_names(pipeline_class):
@@ -120,12 +129,41 @@ def infer_model_type(pipeline_class_name):
return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)
def build_component(component_name, original_config, checkpoint, **kwargs):
def build_component(pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
if component_name in kwargs:
return kwargs.pop(component_name, None)
component_class = getattr(importlib.import_module("diffusers"), component_name)
if component_name == "unet":
unet = create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return unet
if component_name == "controlnet":
controlnet = create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return controlnet
if component_name == "vae":
vae = create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs)
return vae
if component_name in ["text_encoder", "text_encoder_2"]:
text_encoder = create_text_encoder_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return text_encoder
if component_name in ["tokenizer", "tokenizer_2"]:
tokenizer = create_tokenizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return tokenizer
if component_name == "scheduler":
scheduler = create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return scheduler
if component_name == "image_normalizer":
image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return image_normalizer
if component_name == "image_normalizer":
image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs)
return image_normalizer
return
@@ -332,7 +370,7 @@ class FromSingleFileMixin:
pipeline_components = {}
for component in component_names:
pipeline_components[component] = build_component(component, checkpoint, original_config, **kwargs)
pipeline_components[component] = build_component(pipeline_class_name, component, checkpoint, original_config, **kwargs)
pipe = download_from_original_stable_diffusion_ckpt(
pretrained_model_link_or_path,
@@ -359,7 +397,7 @@ class FromSingleFileMixin:
local_files_only=local_files_only,
)
pipe = cls(**pipeline_components, **kwargs)
pipe = cls(**pipeline_components)
if torch_dtype is not None:
pipe.to(dtype=torch_dtype)

View File

@@ -156,7 +156,7 @@ def set_model_type(original_config, model_type=None):
else:
raise ValueError("Unable to infer model type from config")
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
return model_type
@@ -897,10 +897,78 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
return text_model
def create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
def convert_controlnet_checkpoint(
checkpoint,
original_config,
checkpoint_path,
image_size,
upcast_attention,
extract_ema,
use_linear_projection=None,
cross_attention_dim=None,
):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size")
if use_linear_projection is not None:
ctrlnet_config["use_linear_projection"] = use_linear_projection
if cross_attention_dim is not None:
ctrlnet_config["cross_attention_dim"] = cross_attention_dim
ctx = init_empty_weights if is_accelerate_available() else nullcontext
with ctx():
controlnet = ControlNetModel(**ctrlnet_config)
# Some controlnet ckpt files are distributed independently from the rest of the
# model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
if "time_embed.0.weight" in checkpoint:
skip_extract_state_dict = True
else:
skip_extract_state_dict = False
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
checkpoint,
ctrlnet_config,
path=checkpoint_path,
extract_ema=extract_ema,
controlnet=True,
skip_extract_state_dict=skip_extract_state_dict,
)
if is_accelerate_available():
for param_name, param in converted_ctrl_checkpoint.items():
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
else:
controlnet.load_state_dict(converted_ctrl_checkpoint)
return controlnet
def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs):
if "num_in_channels" in kwargs:
num_in_channels = kwargs.pop("num_in_channels")
elif pipeline_class_name in [
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLControlNetInpaintPipeline"]:
num_in_channels = 9
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
num_in_channels = 7
else:
num_in_channels = 4
if "upcast_attention" in kwargs:
upcast_attention = kwargs.pop("upcast_attention")
extract_ema = kwargs.get("extract_ema", False)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["num_in_channels"] = num_in_channels
unet_config["upcast_attention"] = upcast_attention
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=path, extract_ema=extract_ema