1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

controlnet sd 2.1 checkpoint conversions (#2593)

* controlnet sd 2.1 checkpoint conversions

* remove global_step -> make config file mandatory
This commit is contained in:
Will Berman
2023-03-10 08:22:02 -08:00
committed by GitHub
parent f1ab955f64
commit a28acb5dcc
2 changed files with 196 additions and 30 deletions

View File

@@ -0,0 +1,91 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for stable diffusion checkpoints which _only_ contain a contrlnet. """
import argparse
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_controlnet_from_original_ckpt
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--original_config_file",
type=str,
required=True,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument(
"--num_in_channels",
default=None,
type=int,
help="The number of input channels. If `None` number of input channels will be automatically inferred.",
)
parser.add_argument(
"--image_size",
default=512,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
parser.add_argument(
"--extract_ema",
action="store_true",
help=(
"Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
" or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
" higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
),
)
parser.add_argument(
"--upcast_attention",
action="store_true",
help=(
"Whether the attention computation should always be upcasted. This is necessary when running stable"
" diffusion 2.1."
),
)
parser.add_argument(
"--from_safetensors",
action="store_true",
help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
)
parser.add_argument(
"--to_safetensors",
action="store_true",
help="Whether to store pipeline in safetensors format or not.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
args = parser.parse_args()
controlnet = download_controlnet_from_original_ckpt(
checkpoint_path=args.checkpoint_path,
original_config_file=args.original_config_file,
image_size=args.image_size,
extract_ema=args.extract_ema,
num_in_channels=args.num_in_channels,
upcast_attention=args.upcast_attention,
from_safetensors=args.from_safetensors,
device=args.device,
)
controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

View File

@@ -954,6 +954,25 @@ def stable_unclip_image_noising_components(
return image_normalizer, image_noising_scheduler
def convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size")
controlnet_model = ControlNetModel(**ctrlnet_config)
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
)
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
return controlnet_model
def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
original_config_file: str = None,
@@ -1042,7 +1061,9 @@ def download_from_original_stable_diffusion_ckpt(
print("global_step key not found in model")
global_step = None
if "state_dict" in checkpoint:
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
if original_config_file is None:
@@ -1084,6 +1105,14 @@ def download_from_original_stable_diffusion_ckpt(
if image_size is None:
image_size = 512
if controlnet is None:
controlnet = "control_stage_config" in original_config.model.params
if controlnet:
controlnet_model = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
)
num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end
@@ -1143,27 +1172,34 @@ def download_from_original_stable_diffusion_ckpt(
model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
logger.debug(f"no `model_type` given, `model_type` inferred as: {model_type}")
if controlnet is None:
controlnet = "control_stage_config" in original_config.model.params
if controlnet and model_type != "FrozenCLIPEmbedder":
raise ValueError("`controlnet`=True only supports `model_type`='FrozenCLIPEmbedder'")
if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
if stable_unclip is None:
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
if controlnet:
pipe = StableDiffusionControlNetPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
controlnet=controlnet_model,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
else:
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
else:
image_normalizer, image_noising_scheduler = stable_unclip_image_noising_components(
original_config, clip_stats_path=clip_stats_path, device=device
@@ -1238,19 +1274,6 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
if controlnet:
# Convert the ControlNetModel model.
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size")
controlnet_model = ControlNetModel(**ctrlnet_config)
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
)
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
pipe = StableDiffusionControlNetPipeline(
vae=vae,
text_encoder=text_model,
@@ -1278,3 +1301,55 @@ def download_from_original_stable_diffusion_ckpt(
pipe = LDMTextToImagePipeline(vqvae=vae, bert=text_model, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
return pipe
def download_controlnet_from_original_ckpt(
checkpoint_path: str,
original_config_file: str,
image_size: int = 512,
extract_ema: bool = False,
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
device: str = None,
from_safetensors: bool = False,
) -> StableDiffusionPipeline:
if not is_omegaconf_available():
raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
from omegaconf import OmegaConf
if from_safetensors:
if not is_safetensors_available():
raise ValueError(BACKENDS_MAPPING["safetensors"][1])
from safetensors import safe_open
checkpoint = {}
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
for key in f.keys():
checkpoint[key] = f.get_tensor(key)
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
else:
checkpoint = torch.load(checkpoint_path, map_location=device)
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None:
original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = num_in_channels
if "control_stage_config" not in original_config.model.params:
raise ValueError("`control_stage_config` not present in original config")
controlnet_model = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
)
return controlnet_model