mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -55,6 +55,15 @@ MODEL_TYPE_FROM_PIPELINE_CLASS = {
|
||||
"StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder",
|
||||
"StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder",
|
||||
}
|
||||
# Maps each pipeline component name to the name of the class used for it.
# NOTE(review): the "unet" entry had no value at all (a syntax error);
# "UNet2DConditionModel" is assumed to be the intended class name, following
# the pattern of the other entries -- confirm against the pipeline's
# expected UNet class.
PIPELINE_COMPONENTS = {
    "unet": "UNet2DConditionModel",
    "vae": "AutoencoderKL",
    "text_encoder": "CLIPTextModel",
    "text_encoder_2": "CLIPTextModel",
    "tokenizer": "CLIPTokenizer",
    "tokenizer_2": "CLIPTokenizer",
    "scheduler": "DiffusionScheduler",
}
|
||||
|
||||
|
||||
def extract_pipeline_component_names(pipeline_class):
|
||||
@@ -120,12 +129,41 @@ def infer_model_type(pipeline_class_name):
|
||||
return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)
|
||||
|
||||
|
||||
def build_component(pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs):
    """Return the pipeline component identified by `component_name`.

    Args:
        pipeline_class_name: Name of the pipeline class being assembled;
            forwarded unchanged to the `create_*` helper.
        component_name: Which component to build (`"unet"`, `"controlnet"`,
            `"vae"`, `"text_encoder"`, `"text_encoder_2"`, `"tokenizer"`,
            `"tokenizer_2"`, `"scheduler"` or `"image_normalizer"`).
        original_config: Forwarded unchanged to the `create_*` helper.
        checkpoint: Forwarded unchanged to the `create_*` helper.
        checkpoint_path_or_dict: Forwarded unchanged to the `create_*` helper.
        **kwargs: Extra options forwarded to the helper. If `kwargs` contains
            a key equal to `component_name`, that value is returned as-is and
            nothing is built.

    Returns:
        The caller-supplied component, the object returned by the matching
        `create_*` helper, or `None` when `component_name` is not recognised.
    """
    # A component passed in by the caller takes precedence over building one.
    if component_name in kwargs:
        return kwargs.pop(component_name, None)

    # Positional arguments shared by every `create_*` helper call below.
    shared_args = (pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict)

    if component_name == "unet":
        return create_unet_model(*shared_args, **kwargs)

    if component_name == "controlnet":
        return create_controlnet_model(*shared_args, **kwargs)

    if component_name == "vae":
        # `model_type` and `image_size` were previously referenced without
        # being defined here (NameError). They are taken out of `kwargs` so
        # they are not passed twice.
        # NOTE(review): `create_vae_model`'s signature is not visible from
        # this block -- confirm it accepts these two as positional arguments.
        model_type = kwargs.pop("model_type", None)
        image_size = kwargs.pop("image_size", None)
        return create_vae_model(*shared_args, model_type, image_size, **kwargs)

    if component_name in ("text_encoder", "text_encoder_2"):
        return create_text_encoder_model(*shared_args, **kwargs)

    if component_name in ("tokenizer", "tokenizer_2"):
        return create_tokenizer(*shared_args, **kwargs)

    if component_name == "scheduler":
        return create_scheduler(*shared_args, **kwargs)

    if component_name == "image_normalizer":
        return create_image_normalizer(*shared_args, **kwargs)

    # Unknown component names yield no component.
    return None
|
||||
@@ -332,7 +370,7 @@ class FromSingleFileMixin:
|
||||
|
||||
pipeline_components = {}
|
||||
for component in component_names:
|
||||
pipeline_components[component] = build_component(component, checkpoint, original_config, **kwargs)
|
||||
pipeline_components[component] = build_component(pipeline_class_name, component, checkpoint, original_config, **kwargs)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
@@ -359,7 +397,7 @@ class FromSingleFileMixin:
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
pipe = cls(**pipeline_components, **kwargs)
|
||||
pipe = cls(**pipeline_components)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
@@ -156,7 +156,7 @@ def set_model_type(original_config, model_type=None):
|
||||
else:
|
||||
raise ValueError("Unable to infer model type from config")
|
||||
|
||||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
||||
logger.debug(f"No `model_type` given, `model_type` inferred as: {model_type}")
|
||||
|
||||
return model_type
|
||||
|
||||
@@ -897,10 +897,78 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False, text_encoder
|
||||
return text_model
|
||||
|
||||
|
||||
def create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs):
|
||||
def convert_controlnet_checkpoint(
    checkpoint,
    original_config,
    checkpoint_path,
    image_size,
    upcast_attention,
    extract_ema,
    use_linear_projection=None,
    cross_attention_dim=None,
):
    """Build a `ControlNetModel` and load converted checkpoint weights into it.

    The model config comes from `create_unet_diffusers_config(...,
    controlnet=True)` with `upcast_attention` set and `sample_size` removed.
    `use_linear_projection` and `cross_attention_dim` override the config
    only when they are not `None`. The weights are converted with
    `convert_ldm_unet_checkpoint` and then loaded into the model, which is
    returned.
    """
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention
    ctrlnet_config.pop("sample_size")

    # Optional overrides are applied only when the caller supplied a value.
    optional_overrides = (
        ("use_linear_projection", use_linear_projection),
        ("cross_attention_dim", cross_attention_dim),
    )
    for config_key, override in optional_overrides:
        if override is not None:
            ctrlnet_config[config_key] = override

    if is_accelerate_available():
        weight_init_ctx = init_empty_weights
    else:
        weight_init_ctx = nullcontext

    with weight_init_ctx():
        controlnet = ControlNetModel(**ctrlnet_config)

    # Some controlnet ckpt files are distributed independently from the rest of the
    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
    skip_extract_state_dict = "time_embed.0.weight" in checkpoint

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint,
        ctrlnet_config,
        path=checkpoint_path,
        extract_ema=extract_ema,
        controlnet=True,
        skip_extract_state_dict=skip_extract_state_dict,
    )

    if not is_accelerate_available():
        controlnet.load_state_dict(converted_ctrl_checkpoint)
        return controlnet

    # With accelerate available, each converted tensor is assigned
    # individually, targeting the "cpu" device.
    for tensor_name, tensor in converted_ctrl_checkpoint.items():
        set_module_tensor_to_device(controlnet, tensor_name, "cpu", value=tensor)

    return controlnet
|
||||
|
||||
|
||||
def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs):
|
||||
if "num_in_channels" in kwargs:
|
||||
num_in_channels = kwargs.pop("num_in_channels")
|
||||
elif pipeline_class_name in [
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"StableDiffusionXLControlNetInpaintPipeline"]:
|
||||
num_in_channels = 9
|
||||
elif pipeline_class_name == "StableDiffusionUpscalePipeline":
|
||||
num_in_channels = 7
|
||||
else:
|
||||
num_in_channels = 4
|
||||
|
||||
if "upcast_attention" in kwargs:
|
||||
upcast_attention = kwargs.pop("upcast_attention")
|
||||
|
||||
extract_ema = kwargs.get("extract_ema", False)
|
||||
|
||||
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
|
||||
unet_config["num_in_channels"] = num_in_channels
|
||||
unet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
|
||||
diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, unet_config, path=path, extract_ema=extract_ema
|
||||
|
||||
Reference in New Issue
Block a user