
small tweaks for parsing thibaudz controlnet checkpoints (#3657)

Author: Will Berman
Date: 2023-06-05 10:24:31 -07:00
Committed by: GitHub
Commit: 462956be7b (parent 5990014700)
Repository: https://github.com/huggingface/diffusers.git

2 changed files with 86 additions and 29 deletions

scripts/convert_original_controlnet_to_diffusers.py

@@ -75,6 +75,22 @@ if __name__ == "__main__":
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+    # small workaround to get argparser to parse a boolean input as either true _or_ false
+    def parse_bool(string):
+        if string == "True":
+            return True
+        elif string == "False":
+            return False
+        else:
+            raise ValueError(f"could not parse string as bool {string}")
+
+    parser.add_argument(
+        "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool
+    )
+    parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int)
+
     args = parser.parse_args()

     controlnet = download_controlnet_from_original_ckpt(
@@ -86,6 +102,8 @@ if __name__ == "__main__":
         upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
+        use_linear_projection=args.use_linear_projection,
+        cross_attention_dim=args.cross_attention_dim,
     )
     controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
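Note on the parse_bool helper above: argparse's built-in type=bool calls bool() on the raw argument string, and any non-empty string (including "False") is truthy, so a boolean option declared that way can never be set to false from the command line. A minimal demonstration of the pitfall, independent of this script:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--flag", type=bool)

# bool("False") is True, so type=bool silently misparses the input.
args = parser.parse_args(["--flag", "False"])
print(args.flag)  # True

# parse_bool instead maps "True"/"False" to the matching bool and raises
# ValueError for anything else, so both states are reachable from the CLI.

With the new options, a conversion run for an SD 2.x checkpoint would typically pass something like --use_linear_projection True --cross_attention_dim 1024 (illustrative values: SD 2.x UNets use linear attention projections and a 1024-dim text encoder context).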

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """

-    # extract state_dict for UNet
-    unet_state_dict = {}
-    keys = list(checkpoint.keys())
-
-    if controlnet:
-        unet_key = "control_model."
-    else:
-        unet_key = "model.diffusion_model."
-
-    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
-            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
-        )
-        for key in keys:
-            if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-    else:
-        if sum(k.startswith("model_ema") for k in keys) > 100:
-            print(
-                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
-            )
-
-        for key in keys:
-            if key.startswith(unet_key):
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
+    else:
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
+
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                print(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )
+
+            for key in keys:
+                if key.startswith(unet_key):
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)

     new_checkpoint = {}
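A note on the new skip_extract_state_dict branch: a full LDM checkpoint nests the ControlNet weights under the control_model. prefix, which must be stripped while copying keys, whereas a standalone checkpoint already is the bare ControlNet state dict. A minimal sketch of the two layouts (key names other than "time_embed.0.weight" and the "control_model." prefix are illustrative):

# Full LDM-style checkpoint: ControlNet weights nested under "control_model.",
# so the prefix is stripped when building unet_state_dict.
full_ckpt = {"control_model.time_embed.0.weight": 0, "control_model.zero_convs.0.0.bias": 0}

# Standalone ControlNet checkpoint (e.g. thibaud/controlnet-sd21): the state
# dict is already the ControlNet itself, so extraction is skipped entirely.
standalone_ckpt = {"time_embed.0.weight": 0, "zero_convs.0.0.bias": 0}

def is_standalone(ckpt):
    # Same heuristic convert_controlnet_checkpoint uses below: a top-level
    # "time_embed.0.weight" key marks a bare ControlNet state dict.
    return "time_embed.0.weight" in ckpt

assert not is_standalone(full_ckpt)
assert is_standalone(standalone_ckpt)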
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(


 def convert_controlnet_checkpoint(
-    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    checkpoint,
+    original_config,
+    checkpoint_path,
+    image_size,
+    upcast_attention,
+    extract_ema,
+    use_linear_projection=None,
+    cross_attention_dim=None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention

     ctrlnet_config.pop("sample_size")

+    if use_linear_projection is not None:
+        ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+    if cross_attention_dim is not None:
+        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
     controlnet_model = ControlNetModel(**ctrlnet_config)

+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+    if "time_embed.0.weight" in checkpoint:
+        skip_extract_state_dict = True
+    else:
+        skip_extract_state_dict = False
+
     converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+        checkpoint,
+        ctrlnet_config,
+        path=checkpoint_path,
+        extract_ema=extract_ema,
+        controlnet=True,
+        skip_extract_state_dict=skip_extract_state_dict,
     )

     controlnet_model.load_state_dict(converted_ctrl_checkpoint)
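The two config overrides exist because create_unet_diffusers_config infers use_linear_projection and cross_attention_dim from the original YAML config, which may not describe a standalone checkpoint correctly; a mismatch would typically surface as shape errors in load_state_dict. A sketch of the patching, assuming SD 1.x-style inferred defaults:

# Values inferred from the YAML (assumption: SD 1.x-style config).
ctrlnet_config = {"use_linear_projection": False, "cross_attention_dim": 768}

# CLI overrides for an SD 2.x checkpoint; None would keep the inferred value.
use_linear_projection, cross_attention_dim = True, 1024

if use_linear_projection is not None:
    ctrlnet_config["use_linear_projection"] = use_linear_projection
if cross_attention_dim is not None:
    ctrlnet_config["cross_attention_dim"] = cross_attention_dim

print(ctrlnet_config)  # {'use_linear_projection': True, 'cross_attention_dim': 1024}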
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[bool] = None,
 ) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
         raise ValueError("`control_stage_config` not present in original config")

     controlnet_model = convert_controlnet_checkpoint(
-        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        checkpoint,
+        original_config,
+        checkpoint_path,
+        image_size,
+        upcast_attention,
+        extract_ema,
+        use_linear_projection=use_linear_projection,
+        cross_attention_dim=cross_attention_dim,
     )

     return controlnet_model
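End to end, the new keyword arguments flow from the CLI flags through download_controlnet_from_original_ckpt into the ControlNet config. A hypothetical Python-side call (file names are placeholders, and the remaining parameters are assumed to keep their defaults):

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_controlnet_from_original_ckpt,
)

controlnet = download_controlnet_from_original_ckpt(
    checkpoint_path="controlnet-sd21-canny.ckpt",  # placeholder file name
    original_config_file="cldm_v21.yaml",          # placeholder config file
    use_linear_projection=True,                    # SD 2.x: linear attention projections
    cross_attention_dim=1024,                      # SD 2.x: 1024-dim text context
)
controlnet.save_pretrained("./converted-controlnet")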