diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 0a8f69433a..1e670c954f 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -31,7 +31,14 @@ from ..utils import ( logging, ) from ..utils.import_utils import BACKENDS_MAPPING -from .single_file_utils import download_from_original_stable_diffusion_ckpt, fetch_original_config +from .single_file_utils import ( + create_scheduler_components, + create_stable_unclip_components, + create_unet_model, + create_vae_model, + download_from_original_stable_diffusion_ckpt, + fetch_original_config, +) if is_transformers_available(): @@ -43,26 +50,13 @@ if is_accelerate_available(): logger = logging.get_logger(__name__) -DIFFUSER_PIPELINE_CONFIGS = { - "StableDiffusionPipeline": None, - "StableDiffusionImg2ImgPipeline": None, - "StableDiffusionInpaintPipeline": None, - "StableDiffusionControlNetPipeline": None, -} - VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"] -MODEL_TYPE_FROM_PIPELINE_CLASS = { +TEXT_ENCODER_FROM_PIPELINE_CLASS = { "StableUnCLIPPipeline": "FrozenOpenCLIPEmbedder", "StableUnCLIPImg2ImgPipeline": "FrozenOpenCLIPEmbedder", -} -PIPELINE_COMPONENTS = { - "unet": , - "vae": "AutoencoderKL", - "text_encoder": "CLIPTextModel", - "text_encoder_2": "CLIPTextModel", - "tokenizer": "CLIPTokenizer", - "tokenizer_2": "CLIPTokenizer", - "scheduler": "DiffusionScheduler", + "LDMTextToImagePipeline": "LDMTextToImage", + "PaintByExamplePipeline": "PaintByExample", + "StableDiffusion": "stable-diffusion", } @@ -82,7 +76,16 @@ def check_valid_url(pretrained_model_link_or_path): return has_valid_url_prefix -def download_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None): +def download_model_checkpoint( + ckpt_path, + cache_dir=None, + resume_download=False, + force_download=False, + proxies=None, + local_files_only=None, + token=None, + revision=None, +): # get repo_id and (potentially nested) file path of ckpt in repo repo_id = "/".join(ckpt_path.parts[:2]) file_path = "/".join(ckpt_path.parts[2:]) @@ -125,50 +128,96 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True) return checkpoint -def infer_model_type(pipeline_class_name): - return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None) - - -def build_component(pipeline_class_name, component_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): +def build_component( + pipeline_components, + pipeline_class_name, + component_name, + original_config, + checkpoint, + checkpoint_path_or_dict, + **kwargs, +): if component_name in kwargs: return kwargs.pop(component_name, None) - if component_name == "unet": - unet = create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return unet + if component_name in pipeline_components: + return {} - if component_name == "controlnet": - controlnet = create_controlnet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return controlnet + if component_name == "unet": + unet_components = create_unet_model( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return unet_components if component_name == "vae": - vae = create_vae_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs) - return vae + 
vae_components = create_vae_model( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return vae_components - if component_name in ["text_encoder", "text_encoder_2"]: - text_encoder = create_text_encoder_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return text_encoder + if component_name == "controlnet": + controlnet_components = create_controlnet_model( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return controlnet_components - if component_name in ["tokenizer", "tokenizer_2"]: - tokenizer = create_tokenizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return tokenizer + if component_name == "adapter": + adapter_components = create_adapter_model( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return adapter_components if component_name == "scheduler": - scheduler = create_scheduler(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return scheduler - - if component_name == "image_normalizer": - image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return image_normalizer - - if component_name == "image_normalizer": - image_normalizer = create_image_normalizer(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs) - return image_normalizer + scheduler_components = create_scheduler( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return scheduler_components + if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]: + text_encoder_components = create_text_encoders_and_tokenizers( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return text_encoder_components return +def build_additional_components( + pipeline_components, + pipeline_class_name, + component_name, + original_config, + checkpoint, + checkpoint_path_or_dict, + **kwargs, +): + if component_name in kwargs: + return kwargs.pop(component_name, None) + + if component_name in pipeline_components: + return {} + + local_files_only = kwargs.pop("local_files_only", False) + + if pipeline_class_name == ["StableUnCLIPPipeline", "StableUnCLIPImg2ImgPipeline"]: + stable_unclip_components = create_stable_unclip_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return stable_unclip_components + + if pipeline_class_name == "LDMTextToImagePipeline": + ldm_text_to_image_components = create_ldm_text_to_image_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return ldm_text_to_image_components + + if pipeline_class_name == "PaintByExamplePipeline": + paint_by_example_components = create_paint_by_example_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs + ) + return paint_by_example_components + + class FromSingleFileMixin: """ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. 
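(A hedged, illustrative sketch of the return contract the `build_component` dispatch above assumes: every branch hands back a dict keyed by component name, so a single helper may contribute several entries. The component objects below are placeholders assumed to have been built already.)

    pipeline_components = {}
    # the text-encoder branch is expected to yield the encoder and its tokenizer together
    pipeline_components.update({"text_encoder": text_encoder, "tokenizer": tokenizer})
    # single-model branches contribute one entry each
    pipeline_components.update({"unet": unet})
    pipeline_components.update({"vae": vae})
    pipeline_components.update({"scheduler": scheduler})
    # the loader then instantiates the pipeline directly from the merged dict
    pipe = StableDiffusionPipeline(
        **pipeline_components, safety_checker=None, feature_extractor=None, requires_safety_checker=False
    )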
@@ -281,30 +330,22 @@ class FromSingleFileMixin: """ original_config_file = kwargs.pop("original_config_file", None) config_files = kwargs.pop("config_files", None) - cache_dir = kwargs.pop("cache_dir", None) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) token = kwargs.pop("token", None) + cache_dir = kwargs.pop("cache_dir", None) + local_files_only = kwargs.pop("local_files_only", None) revision = kwargs.pop("revision", None) + torch_dtype = kwargs.pop("torch_dtype", None) + use_safetensors = kwargs.pop("use_safetensors", None) + load_safety_checker = kwargs.pop("load_safety_checker", True) + extract_ema = kwargs.pop("extract_ema", False) image_size = kwargs.pop("image_size", None) scheduler_type = kwargs.pop("scheduler_type", "pndm") num_in_channels = kwargs.pop("num_in_channels", None) upcast_attention = kwargs.pop("upcast_attention", None) - load_safety_checker = kwargs.pop("load_safety_checker", True) prediction_type = kwargs.pop("prediction_type", None) - text_encoder = kwargs.pop("text_encoder", None) - text_encoder_2 = kwargs.pop("text_encoder_2", None) - vae = kwargs.pop("vae", None) - controlnet = kwargs.pop("controlnet", None) - adapter = kwargs.pop("adapter", None) - tokenizer = kwargs.pop("tokenizer", None) - tokenizer_2 = kwargs.pop("tokenizer_2", None) - - torch_dtype = kwargs.pop("torch_dtype", None) - - use_safetensors = kwargs.pop("use_safetensors", None) pipeline_name = cls.__name__ file_extension = pretrained_model_link_or_path.rsplit(".", 1)[-1] @@ -313,42 +354,7 @@ class FromSingleFileMixin: if from_safetensors and use_safetensors is False: raise ValueError("Make sure to install `safetensors` with `pip install safetensors`.") - # TODO: For now we only support stable diffusion - stable_unclip = None - model_type = None - - if pipeline_name in [ - "StableDiffusionControlNetPipeline", - "StableDiffusionControlNetImg2ImgPipeline", - "StableDiffusionControlNetInpaintPipeline", - ]: - from ..models.controlnet import ControlNetModel - from ..pipelines.controlnet.multicontrolnet import MultiControlNetModel - - # list/tuple or a single instance of ControlNetModel or MultiControlNetModel - if not ( - isinstance(controlnet, (ControlNetModel, MultiControlNetModel)) - or isinstance(controlnet, (list, tuple)) - and isinstance(controlnet[0], ControlNetModel) - ): - raise ValueError("ControlNet needs to be passed if loading from ControlNet pipeline.") - elif "StableDiffusion" in pipeline_name: - # Model type will be inferred from the checkpoint. - pass - elif pipeline_name == "StableUnCLIPPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "txt2img" - elif pipeline_name == "StableUnCLIPImg2ImgPipeline": - model_type = "FrozenOpenCLIPEmbedder" - stable_unclip = "img2img" - elif pipeline_name == "PaintByExamplePipeline": - model_type = "PaintByExample" - elif pipeline_name == "LDMTextToImagePipeline": - model_type = "LDMTextToImage" - else: - raise ValueError(f"Unhandled pipeline class: {pipeline_name}") - - has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path) + has_valid_url_prefix = check_valid_url(pretrained_model_link_or_path) # Code based on diffusers.pipelines.pipeline_utils.DiffusionPipeline.from_pretrained ckpt_path = Path(pretrained_model_link_or_path) @@ -356,9 +362,16 @@ class FromSingleFileMixin: raise ValueError( f"The provided path is either not a file or a valid huggingface URL was not provided. 
Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}" ) - pretrained_model_link_or_path = download_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision) + pretrained_model_link_or_path = download_model_checkpoint( + ckpt_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) - global_step = checkpoint["global_step"] if "global_step" in checkpoint else None # NOTE: this while loop isn't great but this controlnet checkpoint has one additional # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 @@ -370,32 +383,15 @@ class FromSingleFileMixin: pipeline_components = {} for component in component_names: - pipeline_components[component] = build_component(pipeline_name, component, checkpoint, original_config, **kwargs) + components = build_component( + pipeline_components, pipeline_name, component, checkpoint, original_config, **kwargs + ) + pipeline_components.update(components) - pipe = download_from_original_stable_diffusion_ckpt( - pretrained_model_link_or_path, - pipeline_class=cls, - model_type=model_type, - stable_unclip=stable_unclip, - controlnet=controlnet, - adapter=adapter, - from_safetensors=from_safetensors, - extract_ema=extract_ema, - image_size=image_size, - scheduler_type=scheduler_type, - num_in_channels=num_in_channels, - upcast_attention=upcast_attention, - load_safety_checker=load_safety_checker, - prediction_type=prediction_type, - text_encoder=text_encoder, - text_encoder_2=text_encoder_2, - vae=vae, - tokenizer=tokenizer, - tokenizer_2=tokenizer_2, - original_config_file=original_config_file, - config_files=config_files, - local_files_only=local_files_only, - ) + additional_components = set(pipeline_components.keys() - component_names) + if additional_components: + components = build_additional_components(pipeline_name, component, checkpoint, original_config, **kwargs) + pipeline_components.update(components) pipe = cls(**pipeline_components) diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py index db78981e24..cd1153e90f 100644 --- a/src/diffusers/loaders/single_file_utils.py +++ b/src/diffusers/loaders/single_file_utils.py @@ -26,16 +26,21 @@ from safetensors.torch import load_file as safe_load from transformers import ( AutoFeatureExtractor, BertTokenizerFast, + CLIPImageProcessor, + CLIPTextConfig, + CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, + CLIPVisionModel, + CLIPVisionModelWithProjection, + CLIPVisionTextModel, + CLIPVisionTextModelWithProjection, ) -from ...models import ( - AutoencoderKL, - PriorTransformer, - UNet2DConditionModel, -) -from ...schedulers import ( +from ..models import AutoencoderKL, ControlNetModel, PriorTransformer, UNet2DConditionModel +from ..pipelines.pipeline_utils import DiffusionPipeline +from ..pipelines.stable_unclip.stable_unclip_image_normalizer import StableUnCLIPImageNormalizer +from ..schedulers import ( DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, @@ -46,9 +51,8 @@ from ...schedulers import ( PNDMScheduler, UnCLIPScheduler, ) -from ...utils import is_accelerate_available, is_omegaconf_available, logging -from ...utils.import_utils import BACKENDS_MAPPING -from ..pipeline_utils import DiffusionPipeline +from ..utils import 
is_accelerate_available, is_omegaconf_available, logging +from ..utils.import_utils import BACKENDS_MAPPING from .safety_checker import StableDiffusionSafetyChecker @@ -62,7 +66,7 @@ CONFIG_URLS = { "v1": "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", "v2": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml", "xl": "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml", - "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml" + "upscale": "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml", } CHECKPOINT_KEY_NAMES = { @@ -71,6 +75,20 @@ CHECKPOINT_KEY_NAMES = { "xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias", } +SCHEDULER_DEFAULT_CONFIG = { + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "beta_end": 0.012, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": False, + "skip_prk_steps": True, + "steps_offset": 1, + "timestep_spacing": "leading", +} + textenc_conversion_lst = [ ("positional_embedding", "text_model.embeddings.position_embedding.weight"), ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"), @@ -109,7 +127,7 @@ def fetch_original_config_file_from_url(checkpoint): else: config_url = CONFIG_URLS["v1"] - #TODO: Add upscale config + # TODO: Add upscale config original_config_file = BytesIO(requests.get(config_url).content) @@ -129,7 +147,7 @@ def fetch_original_config_file_from_file(checkpoint, config_files: list): if "xl_refiner" in config_files: return config_files["xl_refiner"] - #TODO: Add upscale config + # TODO: Add upscale config return @@ -162,12 +180,21 @@ def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True) return checkpoint -def set_model_type(original_config, model_type=None): +def infer_model_type(pipeline_class_name, original_config, model_type=None): if model_type is not None: return model_type - has_cond_stage_config = "cond_stage_config" in original_config.model.params and original_config.model.params.cond_stage_config is not None - has_network_config = "network_config" in original_config.model.params and original_config.model.params.network_config is not None + if pipeline_class_name in ["StableUnCLIPPipeline", "StableUnCLIPImg2ImgPipeline"]: + model_type = "FrozenOpenCLIPEmbedder" + return model_type + + has_cond_stage_config = ( + "cond_stage_config" in original_config.model.params + and original_config.model.params.cond_stage_config is not None + ) + has_network_config = ( + "network_config" in original_config.model.params and original_config.model.params.network_config is not None + ) if has_cond_stage_config: model_type = original_config.model.params.cond_stage_config.target.split(".")[-1] @@ -185,6 +212,11 @@ def set_model_type(original_config, model_type=None): return model_type + +def get_default_scheduler_config(): + return SCHEDULER_DEFAULT_CONFIG + + def shave_segments(path, n_shave_prefix_segments=1): """ Removes segments. Positive values shave the first segments, negative shave the last segments. 
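(An illustrative aside, assuming the names added above, on how `SCHEDULER_DEFAULT_CONFIG` and `get_default_scheduler_config()` are meant to feed the scheduler builder introduced later in this diff; the v-prediction override is only an example value.)

    scheduler_config = dict(get_default_scheduler_config())  # copy SCHEDULER_DEFAULT_CONFIG before mutating it
    scheduler_config["num_train_timesteps"] = 1000
    scheduler_config["prediction_type"] = "v_prediction"  # e.g. a v-parameterized SD 2.x checkpoint
    scheduler = DDIMScheduler.from_config(scheduler_config)  # "ddim" is the fallback scheduler type; unknown keys are ignored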
@@ -350,6 +382,7 @@ def conv_attn_to_linear(checkpoint): if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0] + def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): """ Creates a config for the diffusers based on the config of the LDM model. @@ -971,6 +1004,7 @@ def convert_controlnet_checkpoint( return controlnet + def convert_open_clip_checkpoint( checkpoint, config_name, @@ -1053,22 +1087,173 @@ def convert_open_clip_checkpoint( return text_model +def stable_unclip_image_encoder(original_config, local_files_only=False): + """ + Returns the image processor and clip image encoder for the img2img unclip pipeline. + + We currently know of two types of stable unclip models which separately use the clip and the openclip image + encoders. + """ + + image_embedder_config = original_config.model.params.embedder_config + + sd_clip_image_embedder_class = image_embedder_config.target + sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] + + if sd_clip_image_embedder_class == "ClipImageEmbedder": + clip_model_name = image_embedder_config.params.model + + if clip_model_name == "ViT-L/14": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "openai/clip-vit-large-patch14", local_files_only=local_files_only + ) + else: + raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") + + elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": + feature_extractor = CLIPImageProcessor() + image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", local_files_only=local_files_only + ) + else: + raise NotImplementedError( + f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" + ) + + return feature_extractor, image_encoder + + +def convert_paint_by_example_checkpoint(checkpoint, local_files_only=False): + config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14", local_files_only=local_files_only) + model = PaintByExampleImageEncoder(config) + + keys = list(checkpoint.keys()) + + text_model_dict = {} + + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # load clip vision + model.model.load_state_dict(text_model_dict) + + # load mapper + keys_mapper = { + k[len("cond_stage_model.mapper.res") :]: v + for k, v in checkpoint.items() + if k.startswith("cond_stage_model.mapper") + } + + MAPPING = { + "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], + "attn.c_proj": ["attn1.to_out.0"], + "ln_1": ["norm1"], + "ln_2": ["norm3"], + "mlp.c_fc": ["ff.net.0.proj"], + "mlp.c_proj": ["ff.net.2"], + } + + mapped_weights = {} + for key, value in keys_mapper.items(): + prefix = key[: len("blocks.i")] + suffix = key.split(prefix)[-1].split(".")[-1] + name = key.split(prefix)[-1].split(suffix)[0][1:-1] + mapped_names = MAPPING[name] + + num_splits = len(mapped_names) + for i, mapped_name in enumerate(mapped_names): + new_name = ".".join([prefix, mapped_name, suffix]) + shape = value.shape[0] // num_splits + mapped_weights[new_name] = value[i * shape : (i + 1) * shape] + + model.mapper.load_state_dict(mapped_weights) + + # load final layer norm + model.final_layer_norm.load_state_dict( + { + "bias": checkpoint["cond_stage_model.final_ln.bias"], + "weight": checkpoint["cond_stage_model.final_ln.weight"], + } + ) + + # load 
final proj + model.proj_out.load_state_dict( + { + "bias": checkpoint["proj_out.bias"], + "weight": checkpoint["proj_out.weight"], + } + ) + + # load uncond vector + model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) + return model + + +def stable_unclip_image_noising_components( + original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None +): + """ + Returns the noising components for the img2img and txt2img unclip pipelines. + + Converts the stability noise augmentor into + 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats + 2. a `DDPMScheduler` for holding the noise schedule + + If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. + """ + noise_aug_config = original_config.model.params.noise_aug_config + noise_aug_class = noise_aug_config.target + noise_aug_class = noise_aug_class.split(".")[-1] + + if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": + noise_aug_config = noise_aug_config.params + embedding_dim = noise_aug_config.timestep_dim + max_noise_level = noise_aug_config.noise_schedule_config.timesteps + beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule + + image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) + image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) + + if "clip_stats_path" in noise_aug_config: + if clip_stats_path is None: + raise ValueError("This stable unclip config requires a `clip_stats_path`") + + clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) + clip_mean = clip_mean[None, :] + clip_std = clip_std[None, :] + + clip_stats_state_dict = { + "mean": clip_mean, + "std": clip_std, + } + + image_normalizer.load_state_dict(clip_stats_state_dict) + else: + raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") + + return image_normalizer, image_noising_scheduler + + def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, image_size, **kwargs): if "num_in_channels" in kwargs: - num_in_channels = kwargs.pop("num_in_channels") + num_in_channels = kwargs.get("num_in_channels") + elif pipeline_class_name in [ "StableDiffusionInpaintPipeline", "StableDiffusionXLInpaintPipeline", - "StableDiffusionXLControlNetInpaintPipeline"]: + "StableDiffusionXLControlNetInpaintPipeline", + ]: num_in_channels = 9 + elif pipeline_class_name == "StableDiffusionUpscalePipeline": num_in_channels = 7 + else: num_in_channels = 4 - if "upcast_attention" in kwargs: - upcast_attention = kwargs.pop("upcast_attention") - + upcast_attention = kwargs.get("upcast_attention", False) extract_ema = kwargs.get("extract_ema", False) unet_config = create_unet_diffusers_config(original_config, image_size=image_size) @@ -1092,9 +1277,8 @@ def create_unet_model(pipeline_class_name, original_config, checkpoint, checkpoi return unet -def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs): +def create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, **kwargs): vae_config = create_vae_diffusers_config(original_config, image_size=image_size) - path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else "" diffusers_format_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config) ctx = init_empty_weights if is_accelerate_available() else nullcontext with ctx(): @@ -1109,6 +1293,269 @@ def create_vae_model(original_config, 
checkpoint, checkpoint_path_or_dict, model return vae +def create_text_encoder_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs +): + model_type = infer_model_type(pipeline_class_name, original_config) + local_files_only = kwargs.get("local_files_only", False) + + if model_type == "FrozenOpenCLIPEmbedder": + config_name = "stabilityai/stable-diffusion-2" + config_kwargs = {"subfolder": "text_encoder"} + + try: + text_encoder = convert_open_clip_checkpoint( + checkpoint, config_name, local_files_only=local_files_only, **config_kwargs + ) + tokenizer = CLIPTokenizer.from_pretrained( + config_name, subfolder="tokenizer", local_files_only=local_files_only + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "FrozenCLIPEmbedder": + try: + config_name = "openai/clip-vit-large-patch14" + text_encoder = convert_ldm_clip_checkpoint( + checkpoint, local_files_only=local_files_only, text_encoder=None + ) + tokenizer = CLIPTokenizer.from_pretrained( + config_name, subfolder="tokenizer", local_files_only=local_files_only + ) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." + ) + else: + return {"text_encoder": text_encoder, "tokenizer": tokenizer} + + elif model_type == "SDXL-Refiner": + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + config_kwargs = {"projection_dim": 1280} + prefix = "conditioner.embedders.0.model." + + try: + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." + ) + + else: + return { + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + elif model_type == "SDXL": + try: + config_name = "openai/clip-vit-large-patch14" + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + text_encoder = convert_ldm_clip_checkpoint(checkpoint, local_files_only=local_files_only) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder and tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + try: + config_name = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k" + prefix = "conditioner.embedders.1.model." + tokenizer_2 = CLIPTokenizer.from_pretrained(config_name, pad_token="!", local_files_only=local_files_only) + text_encoder_2 = convert_open_clip_checkpoint( + checkpoint, + config_name, + prefix=prefix, + has_projection=True, + local_files_only=local_files_only, + **config_kwargs, + ) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the text_encoder_2 and tokenizer_2 in the following path: {config_name} with `pad_token` set to '!'." 
+ ) + + return { + "tokenizer": tokenizer, + "text_encoder": text_encoder, + "tokenizer_2": tokenizer_2, + "text_encoder_2": text_encoder_2, + } + + return + + +def create_scheduler_component(pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs): + scheduler_config = get_default_scheduler_config() + model_type = infer_model_type(pipeline_class_name, original_config) + + scheduler_type = kwargs.get("scheduler_type", "ddim") + prediction_type = kwargs.get("prediction_type", None) + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + + num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000 + scheduler_config["num_train_timesteps"] = num_train_timesteps + + if ( + "parameterization" in original_config["model"]["params"] + and original_config["model"]["params"]["parameterization"] == "v" + ): + if prediction_type is None: + # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"` + # as it relies on a brittle global step parameter here + prediction_type = "epsilon" if global_step == 875000 else "v_prediction" + + else: + prediction_type = prediction_type or "epsilon" + + scheduler_config["prediction_type"] = prediction_type + + if model_type in ["SDXL", "SDXL-Refiner"]: + scheduler_type = "euler" + + else: + beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02 + beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085 + scheduler_config["beta_start"] = beta_start + scheduler_config["beta_end"] = beta_end + scheduler_config["beta_schedule"] = "scaled_linear" + scheduler_config["clip_sample"] = False + scheduler_config["set_alpha_to_one"] = False + + scheduler_type = "ddim" + + if scheduler_type == "pndm": + scheduler_config["skip_prk_steps"] = True + scheduler = PNDMScheduler.from_config(scheduler_config) + + elif scheduler_type == "lms": + scheduler = LMSDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "heun": + scheduler = HeunDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler": + scheduler = EulerDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "euler-ancestral": + scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) + + elif scheduler_type == "dpm": + scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) + + elif scheduler_type == "ddim": + scheduler = DDIMScheduler.from_config(scheduler_config) + + else: + raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!") + + return {"scheduler": scheduler} + + +def create_stable_unclip_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs +): + components = {} + + local_files_only = kwargs.get("local_files_only", False) + clip_stats_path = kwargs.get("clip_stats_path", None) + + image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components( + original_config, + clip_stats_path=clip_stats_path, + ) + + if pipeline_class_name == "StableUnCLIPPipeline": + stable_unclip_prior = kwargs.get("stable_unclip_prior", None) + if stable_unclip_prior is None and stable_unclip_prior != "karlo": + raise NotImplementedError(f"Unknown prior for Stable UnCLIP model: {stable_unclip_prior}") + + try: + config_name = "kakaobrain/karlo-v1-alpha" + prior = PriorTransformer.from_pretrained(config_name, subfolder="prior", local_files_only=local_files_only) + except Exception as e: + raise ValueError( + 
f"With local_files_only set to {local_files_only}, you must first locally save the prior in the following path: '{config_name}'." + ) + + try: + config_name = "openai/clip-vit-large-patch14" + prior_tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + prior_text_encoder = CLIPTextModelWithProjection.from_pretrained( + config_name, local_files_only=local_files_only + ) + prior_scheduler = DDPMScheduler.from_pretrained( + config_name, subfolder="prior_scheduler", local_files_only=local_files_only + ) + + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: '{config_name}'." + ) + else: + return { + "prior": prior, + "prior_tokenizer": prior_tokenizer, + "prior_text_encoder": prior_text_encoder, + "prior_scheduler": prior_scheduler, + "image_normalizer": image_normalizer, + "image_noise_scheduler": image_noising_scheduler, + } + + else: + feature_extractor, image_encoder = stable_unclip_image_encoder(original_config) + + return { + "feature_extractor": feature_extractor, + "image_encoder": image_encoder, + "image_normalizer": image_normalizer, + "image_noising_scheduler": image_noising_scheduler, + } + + return + + +def create_paint_by_example_components( + pipeline_class_name, original_config, checkpoint, checkpoint_path_or_dict, **kwargs +): + local_files_only = kwargs.get("local_files_only", False) + image_encoder = convert_paint_by_example_checkpoint(checkpoint) + + try: + config_name = "openai/clip-vit-large-patch14" + tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the tokenizer in the following path: 'openai/clip-vit-large-patch14'." + ) + + try: + config_name = "CompVis/stable-diffusion-safety-checker" + feature_extractor = AutoFeatureExtractor.from_pretrained(config_name, local_files_only=local_files_only) + except Exception: + raise ValueError( + f"With local_files_only set to {local_files_only}, you must first locally save the feature_extractor in the following path: 'CompVis/stable-diffusion-safety-checker'." + ) + + return { + "image_encoder": image_encoder, + "tokenizer": tokenizer, + "feature_extractor": feature_extractor, + } + def download_from_original_stable_diffusion_ckpt( checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]], @@ -1137,7 +1584,7 @@ def download_from_original_stable_diffusion_ckpt( tokenizer=None, tokenizer_2=None, config_files=None, - **kwargs + **kwargs, ) -> DiffusionPipeline: """ Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` @@ -1238,7 +1685,7 @@ def download_from_original_stable_diffusion_ckpt( checkpoint = checkpoint["state_dict"] original_config = fetch_original_config(checkpoint, config_files) - model_type = set_model_type(original_config, model_type) + model_type = infer_model_type(original_config, model_type) unet = create_unet_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs) vae = create_vae_model(original_config, checkpoint, checkpoint_path_or_dict, model_type, image_size, **kwargs) @@ -1696,4 +2143,3 @@ def download_from_original_stable_diffusion_ckpt( pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler) return pipe -
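For context, a minimal end-to-end sketch of how the refactored loader is exercised; the checkpoint URL and dtype are illustrative, and `from_single_file` is the `FromSingleFileMixin` entry point rewired above to assemble the pipeline from the new per-component builders.

    import torch
    from diffusers import StableDiffusionPipeline

    # any single-file .ckpt/.safetensors path or Hub URL is accepted; this one is illustrative
    ckpt = "https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"

    pipe = StableDiffusionPipeline.from_single_file(ckpt, torch_dtype=torch.float16)
    pipe = pipe.to("cuda")
    image = pipe("an astronaut riding a horse on the moon").images[0]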