small tweaks for parsing thibaudz controlnet checkpoints (#3657)
@@ -75,6 +75,22 @@ if __name__ == "__main__":
     )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
     parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
+
+    # small workaround to get argparser to parse a boolean input as either true _or_ false
+    def parse_bool(string):
+        if string == "True":
+            return True
+        elif string == "False":
+            return False
+        else:
+            raise ValueError(f"could not parse string as bool {string}")
+
+    parser.add_argument(
+        "--use_linear_projection", help="Override for use linear projection", required=False, type=parse_bool
+    )
+
+    parser.add_argument("--cross_attention_dim", help="Override for cross attention_dim", required=False, type=int)
+
     args = parser.parse_args()
 
     controlnet = download_controlnet_from_original_ckpt(
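Note on the `parse_bool` helper added above: argparse cannot use `type=bool` directly, because the type callable receives the raw string and any non-empty string, including "False", is truthy. A minimal, self-contained illustration of the pitfall; the `--naive_flag` option is hypothetical and not part of the script:

# Hypothetical demonstration of why type=bool does not work for argparse flags.
import argparse

demo = argparse.ArgumentParser()
demo.add_argument("--naive_flag", type=bool)  # hypothetical flag, shows the pitfall
args = demo.parse_args(["--naive_flag", "False"])
print(args.naive_flag)  # True -- bool("False") is truthy, which is why parse_bool is needed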
@@ -86,6 +102,8 @@ if __name__ == "__main__":
         upcast_attention=args.upcast_attention,
         from_safetensors=args.from_safetensors,
         device=args.device,
+        use_linear_projection=args.use_linear_projection,
+        cross_attention_dim=args.cross_attention_dim,
     )
 
     controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
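For orientation, a hedged sketch of how the new overrides are meant to be used from Python. The thibaud/controlnet-sd21 checkpoints target Stable Diffusion 2.1, whose UNet transformer blocks use linear projections and a 1024-dim text-encoder cross-attention (SD 1.x uses 768), so those are the values one would pass. The import path, the file names, and the `checkpoint_path` / `original_config_file` keywords are assumptions not shown in this diff; only the two new overrides are taken from it:

# Hedged usage sketch -- not part of the commit. Import path, checkpoint_path,
# original_config_file and the file names below are assumptions; use_linear_projection
# and cross_attention_dim are the overrides this commit adds.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_controlnet_from_original_ckpt,
)

controlnet = download_controlnet_from_original_ckpt(
    checkpoint_path="control_v11p_sd21_canny.safetensors",  # hypothetical local file
    original_config_file="cldm_v21.yaml",                   # hypothetical LDM config
    from_safetensors=True,
    device="cpu",
    use_linear_projection=True,  # SD 2.x transformer blocks use linear (not conv) projections
    cross_attention_dim=1024,    # SD 2.x text-encoder width; SD 1.x uses 768
)
controlnet.save_pretrained("converted-controlnet")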
@@ -339,41 +339,46 @@ def create_ldm_bert_config(original_config):
     return config
 
 
-def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
+def convert_ldm_unet_checkpoint(
+    checkpoint, config, path=None, extract_ema=False, controlnet=False, skip_extract_state_dict=False
+):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """
 
-    # extract state_dict for UNet
-    unet_state_dict = {}
-    keys = list(checkpoint.keys())
-
-    if controlnet:
-        unet_key = "control_model."
+    if skip_extract_state_dict:
+        unet_state_dict = checkpoint
     else:
-        unet_key = "model.diffusion_model."
+        # extract state_dict for UNet
+        unet_state_dict = {}
+        keys = list(checkpoint.keys())
 
-    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
-    if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
-        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
-        print(
-            "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
-            " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
-        )
-        for key in keys:
-            if key.startswith("model.diffusion_model"):
-                flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
-    else:
-        if sum(k.startswith("model_ema") for k in keys) > 100:
+        if controlnet:
+            unet_key = "control_model."
+        else:
+            unet_key = "model.diffusion_model."
+
+        # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+        if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
+            print(f"Checkpoint {path} has both EMA and non-EMA weights.")
             print(
-                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
-                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
             )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            if sum(k.startswith("model_ema") for k in keys) > 100:
+                print(
+                    "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                    " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+                )
 
-        for key in keys:
-            if key.startswith(unet_key):
-                unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+            for key in keys:
+                if key.startswith(unet_key):
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
 
     new_checkpoint = {}
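The refactor above wraps the existing extraction logic in an `else:` branch so that a checkpoint which is already a bare state dict can be used unchanged. A minimal sketch of the prefix extraction that `skip_extract_state_dict=True` bypasses; keys and values are hypothetical placeholders, not real checkpoint contents:

# Sketch of the prefix stripping performed when skip_extract_state_dict is False.
checkpoint = {
    "control_model.time_embed.0.weight": "w0",
    "control_model.input_blocks.0.0.weight": "w1",
    "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "w2",
}
unet_key = "control_model."  # would be "model.diffusion_model." for a plain UNet
unet_state_dict = {
    key.replace(unet_key, ""): checkpoint.pop(key)
    for key in list(checkpoint)  # snapshot of keys, so popping while iterating is safe
    if key.startswith(unet_key)
}
print(unet_state_dict)  # {'time_embed.0.weight': 'w0', 'input_blocks.0.0.weight': 'w1'}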
@@ -956,17 +961,42 @@ def stable_unclip_image_noising_components(
 
 
 def convert_controlnet_checkpoint(
-    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+    checkpoint,
+    original_config,
+    checkpoint_path,
+    image_size,
+    upcast_attention,
+    extract_ema,
+    use_linear_projection=None,
+    cross_attention_dim=None,
 ):
     ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
     ctrlnet_config["upcast_attention"] = upcast_attention
 
     ctrlnet_config.pop("sample_size")
 
+    if use_linear_projection is not None:
+        ctrlnet_config["use_linear_projection"] = use_linear_projection
+
+    if cross_attention_dim is not None:
+        ctrlnet_config["cross_attention_dim"] = cross_attention_dim
+
     controlnet_model = ControlNetModel(**ctrlnet_config)
 
+    # Some controlnet ckpt files are distributed independently from the rest of the
+    # model components i.e. https://huggingface.co/thibaud/controlnet-sd21/
+    if "time_embed.0.weight" in checkpoint:
+        skip_extract_state_dict = True
+    else:
+        skip_extract_state_dict = False
+
     converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
+        checkpoint,
+        ctrlnet_config,
+        path=checkpoint_path,
+        extract_ema=extract_ema,
+        controlnet=True,
+        skip_extract_state_dict=skip_extract_state_dict,
     )
 
     controlnet_model.load_state_dict(converted_ctrl_checkpoint)
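The `"time_embed.0.weight"` membership test above distinguishes standalone ControlNet checkpoints, such as https://huggingface.co/thibaud/controlnet-sd21/, which already store bare UNet-style keys, from full Stable Diffusion dumps where the ControlNet weights sit under the `control_model.` prefix. A small sketch of that detection; the key names are illustrative and the values are placeholders:

# Sketch of the checkpoint-layout detection (hypothetical keys).
standalone_ckpt = {"time_embed.0.weight": 0, "input_blocks.0.0.weight": 0}
bundled_ckpt = {"control_model.time_embed.0.weight": 0, "model.diffusion_model.out.2.bias": 0}

for ckpt in (standalone_ckpt, bundled_ckpt):
    skip_extract_state_dict = "time_embed.0.weight" in ckpt
    print(skip_extract_state_dict)  # True for the standalone layout, False for the bundled one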
@@ -1344,6 +1374,8 @@ def download_controlnet_from_original_ckpt(
     upcast_attention: Optional[bool] = None,
     device: str = None,
     from_safetensors: bool = False,
+    use_linear_projection: Optional[bool] = None,
+    cross_attention_dim: Optional[bool] = None,
 ) -> DiffusionPipeline:
     if not is_omegaconf_available():
         raise ValueError(BACKENDS_MAPPING["omegaconf"][1])
@@ -1381,7 +1413,14 @@ def download_controlnet_from_original_ckpt(
         raise ValueError("`control_stage_config` not present in original config")
 
     controlnet_model = convert_controlnet_checkpoint(
-        checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+        checkpoint,
+        original_config,
+        checkpoint_path,
+        image_size,
+        upcast_attention,
+        extract_ema,
+        use_linear_projection=use_linear_projection,
+        cross_attention_dim=cross_attention_dim,
     )
 
     return controlnet_model