mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
controlnet sd 2.1 checkpoint conversions (#2593)
* controlnet sd 2.1 checkpoint conversions * remove global_step -> make config file mandatory
This commit is contained in:
91
scripts/convert_original_controlnet_to_diffusers.py
Normal file
91
scripts/convert_original_controlnet_to_diffusers.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Conversion script for stable diffusion checkpoints which _only_ contain a controlnet. """
|
||||
|
||||
import argparse
|
||||
|
||||
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: convert an original-format controlnet checkpoint into a
    # diffusers `ControlNetModel` and save it to `--dump_path`.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--original_config_file",
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--image_size",
        default=512,
        type=int,
        help=(
            # Fixed typo in user-facing help text: "Siffusion" -> "Diffusion".
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--upcast_attention",
        action="store_true",
        help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    args = parser.parse_args()

    # Perform the actual conversion; all heavy lifting lives in the library helper.
    controlnet = download_controlnet_from_original_ckpt(
        checkpoint_path=args.checkpoint_path,
        original_config_file=args.original_config_file,
        image_size=args.image_size,
        extract_ema=args.extract_ema,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=args.from_safetensors,
        device=args.device,
    )

    controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
||||
@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
|
||||
return image_normalizer, image_noising_scheduler
|
||||
|
||||
|
||||
def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    """Convert an original-format controlnet state dict into a `ControlNetModel`.

    The diffusers config is derived from the original YAML's UNet config
    (with `controlnet=True`), the LDM state dict keys are remapped, and the
    resulting weights are loaded into a freshly constructed model.
    """
    # Derive the controlnet config from the original UNet architecture config.
    cnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    cnet_config["upcast_attention"] = upcast_attention

    # The ControlNetModel constructor is not given `sample_size`, so drop it.
    cnet_config.pop("sample_size")

    model = ControlNetModel(**cnet_config)

    # Remap the original LDM-style keys to diffusers naming before loading.
    converted_state_dict = convert_ldm_unet_checkpoint(
        checkpoint, cnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )
    model.load_state_dict(converted_state_dict)

    return model
|
||||
|
||||
|
||||
def download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path: str,
|
||||
original_config_file: str = None,
|
||||
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
print("global_step key not found in model")
|
||||
global_step = None
|
||||
|
||||
if "state_dict" in checkpoint:
|
||||
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
||||
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
if original_config_file is None:
|
||||
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
if image_size is None:
|
||||
image_size = 512
|
||||
|
||||
if controlnet is None:
|
||||
controlnet = "control_stage_config" in original_config.model.params
|
||||
|
||||
if controlnet:
|
||||
controlnet_model = convert_controlnet_checkpoint(
|
||||
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
|
||||
)
|
||||
|
||||
num_train_timesteps = original_config.model.params.timesteps
|
||||
beta_start = original_config.model.params.linear_start
|
||||
beta_end = original_config.model.params.linear_end
|
||||
@@ -1143,27 +1172,34 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
|
||||
|
||||
if controlnet is None:
|
||||
controlnet = "control_stage_config" in original_config.model.params
|
||||
|
||||
if controlnet and model_type != "FrozenCLIPEmbedder":
|
||||
raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
|
||||
|
||||
if model_type == "FrozenOpenCLIPEmbedder":
|
||||
text_model = convert_open_clip_checkpoint(checkpoint)
|
||||
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
|
||||
|
||||
if stable_unclip is None:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
if controlnet:
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
controlnet=controlnet_model,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
pipe = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
else:
|
||||
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
|
||||
original_config, clip_stats_path=clip_stats_path, device=device
|
||||
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||
|
||||
if controlnet:
|
||||
# Convert the ControlNetModel model.
|
||||
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
|
||||
ctrlnet_config["upcast_attention"] = upcast_attention
|
||||
|
||||
ctrlnet_config.pop("sample_size")
|
||||
|
||||
controlnet_model = ControlNetModel(**ctrlnet_config)
|
||||
|
||||
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
|
||||
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
|
||||
)
|
||||
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
|
||||
|
||||
pipe = StableDiffusionControlNetPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_model,
|
||||
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
|
||||
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
|
||||
|
||||
return pipe
|
||||
|
||||
|
||||
def download_controlnet_from_original_ckpt(
    checkpoint_path: str,
    original_config_file: str,
    image_size: int = 512,
    extract_ema: bool = False,
    num_in_channels: Optional[int] = None,
    upcast_attention: Optional[bool] = None,
    device: Optional[str] = None,
    from_safetensors: bool = False,
) -> ControlNetModel:
    """Load an original-format controlnet-only checkpoint and convert it.

    Args:
        checkpoint_path: Path to the `.ckpt` or `.safetensors` checkpoint.
        original_config_file: YAML config describing the original architecture
            (must contain `model.params.control_stage_config`).
        image_size: Image size the model was trained on (512 or 768).
        extract_ema: Whether to extract EMA weights when both are present.
        num_in_channels: Override for the number of UNet input channels.
        upcast_attention: Whether attention should always be upcasted.
        device: Map location for `torch.load`; defaults to CUDA if available.
        from_safetensors: Load the checkpoint with safetensors instead of PyTorch.

    Returns:
        The converted `ControlNetModel`.
        (Fixed: the previous annotation claimed `StableDiffusionPipeline`,
        but the function has always returned a `ControlNetModel`.)

    Raises:
        ValueError: If omegaconf/safetensors are unavailable, or if the config
            has no `control_stage_config` entry.
    """
    if not is_omegaconf_available():
        raise ValueError(BACKENDS_MAPPING["omegaconf"][1])

    from omegaconf import OmegaConf

    if from_safetensors:
        if not is_safetensors_available():
            raise ValueError(BACKENDS_MAPPING["safetensors"][1])

        from safetensors import safe_open

        # Read all tensors from the safetensors file on CPU.
        checkpoint = {}
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                checkpoint[key] = f.get_tensor(key)
    else:
        # Both original branches called torch.load identically; resolve the
        # device default once and load a single time.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        checkpoint = torch.load(checkpoint_path, map_location=device)

    # NOTE: this while loop isn't great but this controlnet checkpoint has one additional
    # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    original_config = OmegaConf.load(original_config_file)

    if num_in_channels is not None:
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels

    # A controlnet-only checkpoint must describe its control stage in the config.
    if "control_stage_config" not in original_config.model.params:
        raise ValueError("`control_stage_config` not present in original config")

    controlnet_model = convert_controlnet_checkpoint(
        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
    )

    return controlnet_model
|
||||
|
||||
Reference in New Issue
Block a user