diff --git a/scripts/convert_original_controlnet_to_diffusers.py b/scripts/convert_original_controlnet_to_diffusers.py
new file mode 100644
index 0000000000..a9e05abd4c
--- /dev/null
+++ b/scripts/convert_original_controlnet_to_diffusers.py
@@ -0,0 +1,91 @@
+# coding=utf-8
+# Copyright 2023 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Conversion script for Stable Diffusion checkpoints which _only_ contain a ControlNet. """
+
+import argparse
+
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
+    )
+    parser.add_argument(
+        "--original_config_file",
+        type=str,
+        required=True,
+        help="The YAML config file corresponding to the original architecture.",
+    )
+    parser.add_argument(
+        "--num_in_channels",
+        default=None,
+        type=int,
+        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
+    )
+    parser.add_argument(
+        "--image_size",
+        default=512,
+        type=int,
+        help=(
+            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
+            " Base. Use 768 for Stable Diffusion v2."
+        ),
+    )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better for continued fine-tuning."
+        ),
+    )
+    parser.add_argument(
+        "--upcast_attention",
+        action="store_true",
+        help=(
+            "Whether the attention computation should always be upcasted. This is necessary when running Stable"
+            " Diffusion 2.1."
+        ),
+    )
+    parser.add_argument(
+        "--from_safetensors",
+        action="store_true",
+        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
+    )
+    parser.add_argument(
+        "--to_safetensors",
+        action="store_true",
+        help="Whether to store the converted model in safetensors format or not.",
+    )
+    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
+    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+    args = parser.parse_args()
+
+    controlnet = download_controlnet_from_original_ckpt(
+        checkpoint_path=args.checkpoint_path,
+        original_config_file=args.original_config_file,
+        image_size=args.image_size,
+        extract_ema=args.extract_ema,
+        num_in_channels=args.num_in_channels,
+        upcast_attention=args.upcast_attention,
+        from_safetensors=args.from_safetensors,
+        device=args.device,
+    )
+
+    controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
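For reviewers, a minimal sketch of how the converted output is consumed downstream; the local dump directory and the base model id below are illustrative, not part of this PR:

# Load the converted ControlNet back and attach it to a pipeline.
# "./converted-controlnet" stands in for whatever --dump_path was used,
# and the base checkpoint id is just an example.
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("./converted-controlnet")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)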
diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 857fb29676..ef4598433f 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
     return image_normalizer, image_noising_scheduler
 
 
+def convert_controlnet_checkpoint(
+    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+):
+    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
+    ctrlnet_config["upcast_attention"] = upcast_attention
+
+    ctrlnet_config.pop("sample_size")
+
+    controlnet_model = ControlNetModel(**ctrlnet_config)
+
+    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+    )
+
+    controlnet_model.load_state_dict(converted_ctrl_checkpoint)
+
+    return controlnet_model
+
+
 def download_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
             print("global_step key not found in model")
             global_step = None
 
-    if "state_dict" in checkpoint:
+    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+    while "state_dict" in checkpoint:
         checkpoint = checkpoint["state_dict"]
 
     if original_config_file is None:
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
     if image_size is None:
         image_size = 512
 
+    if controlnet is None:
+        controlnet = "control_stage_config" in original_config.model.params
+
+    if controlnet:
+        controlnet_model = convert_controlnet_checkpoint(
+            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        )
+
     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
     beta_end = original_config.model.params.linear_end
@@ -1143,27 +1172,34 @@ def download_from_original_stable_diffusion_ckpt(
         model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
         logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
 
-    if controlnet is None:
-        controlnet = "control_stage_config" in original_config.model.params
-
-    if controlnet and model_type != "FrozenCLIPEmbedder":
-        raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
-
     if model_type == "FrozenOpenCLIPEmbedder":
         text_model = convert_open_clip_checkpoint(checkpoint)
         tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
 
         if stable_unclip is None:
-            pipe = StableDiffusionPipeline(
-                vae=vae,
-                text_encoder=text_model,
-                tokenizer=tokenizer,
-                unet=unet,
-                scheduler=scheduler,
-                safety_checker=None,
-                feature_extractor=None,
-                requires_safety_checker=False,
-            )
+            if controlnet:
+                pipe = StableDiffusionControlNetPipeline(
+                    vae=vae,
+                    text_encoder=text_model,
+                    tokenizer=tokenizer,
+                    unet=unet,
+                    scheduler=scheduler,
+                    controlnet=controlnet_model,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                )
+            else:
+                pipe = StableDiffusionPipeline(
+                    vae=vae,
+                    text_encoder=text_model,
+                    tokenizer=tokenizer,
+                    unet=unet,
+                    scheduler=scheduler,
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                )
         else:
             image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
                 original_config, clip_stats_path=clip_stats_path, device=device
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
             feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
 
         if controlnet:
-            # Convert the ControlNetModel model.
-            ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
-            ctrlnet_config["upcast_attention"] = upcast_attention
-
-            ctrlnet_config.pop("sample_size")
-
-            controlnet_model = ControlNetModel(**ctrlnet_config)
-
-            converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-                checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
-            )
-            controlnet_model.load_state_dict(converted_ctrl_checkpoint)
-
             pipe = StableDiffusionControlNetPipeline(
                 vae=vae,
                 text_encoder=text_model,
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
         pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
 
     return pipe
+
+
+def download_controlnet_from_original_ckpt(
+    checkpoint_path: str,
+    original_config_file: str,
+    image_size: int = 512,
+    extract_ema: bool = False,
+    num_in_channels: Optional[int] = None,
+    upcast_attention: Optional[bool] = None,
+    device: Optional[str] = None,
+    from_safetensors: bool = False,
+) -> ControlNetModel:
+    if not is_omegaconf_available():
+        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
+
+    from omegaconf import OmegaConf
+
+    if from_safetensors:
+        if not is_safetensors_available():
+            raise ValueError(BACKENDS_MAPPING["safetensors"][1])
+
+        from safetensors import safe_open
+
+        checkpoint = {}
+        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                checkpoint[key] = f.get_tensor(key)
+    else:
+        if device is None:
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location=device)
+
+    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
+    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
+    while "state_dict" in checkpoint:
+        checkpoint = checkpoint["state_dict"]
+
+    original_config = OmegaConf.load(original_config_file)
+
+    if num_in_channels is not None:
+        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
+
+    if "control_stage_config" not in original_config.model.params:
+        raise ValueError("`control_stage_config` not present in original config")
+
+    controlnet_model = convert_controlnet_checkpoint(
+        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    )
+
+    return controlnet_model
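For completeness, a minimal sketch of calling the new download_controlnet_from_original_ckpt entry point directly rather than through the script; the checkpoint, YAML, and output paths below are hypothetical:

# Convert a standalone ControlNet checkpoint in-process; the file names
# are placeholders for a real checkpoint/config pair.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_controlnet_from_original_ckpt,
)

controlnet = download_controlnet_from_original_ckpt(
    checkpoint_path="control_canny.safetensors",  # hypothetical checkpoint
    original_config_file="cldm_v15.yaml",         # hypothetical config
    from_safetensors=True,
)
controlnet.save_pretrained("./controlnet-canny", safe_serialization=True)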