mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Write model card in controlnet training script (#3229)
Write model card in controlnet training script.
This commit is contained in:
@@ -60,6 +60,17 @@ check_min_version("0.17.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def image_grid(imgs, rows, cols):
|
||||
assert len(imgs) == rows * cols
|
||||
|
||||
w, h = imgs[0].size
|
||||
grid = Image.new("RGB", size=(cols * w, rows * h))
|
||||
|
||||
for i, img in enumerate(imgs):
|
||||
grid.paste(img, box=(i % cols * w, i // cols * h))
|
||||
return grid
|
||||
|
||||
|
||||
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
@@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
return image_logs
|
||||
|
||||
|
||||
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
|
||||
text_encoder_config = PretrainedConfig.from_pretrained(
|
||||
@@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
|
||||
raise ValueError(f"{model_class} is not supported.")
|
||||
|
||||
|
||||
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
|
||||
img_str = ""
|
||||
if image_logs is not None:
|
||||
img_str = "You can find some example images below.\n"
|
||||
for i, log in enumerate(image_logs):
|
||||
images = log["images"]
|
||||
validation_prompt = log["validation_prompt"]
|
||||
validation_image = log["validation_image"]
|
||||
validation_image.save(os.path.join(repo_folder, "image_control.png"))
|
||||
img_str += f"prompt: {validation_prompt}\n"
|
||||
images = [validation_image] + images
|
||||
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
|
||||
img_str += f"\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: creativeml-openrail-m
|
||||
base_model: {base_model}
|
||||
tags:
|
||||
- stable-diffusion
|
||||
- stable-diffusion-diffusers
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- controlnet
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# controlnet-{repo_id}
|
||||
|
||||
These are controlnet weights trained on {base_model} with new type of conditioning.
|
||||
{img_str}
|
||||
"""
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
|
||||
parser.add_argument(
|
||||
@@ -943,6 +993,7 @@ def main(args):
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
|
||||
image_logs = None
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(controlnet):
|
||||
@@ -1014,7 +1065,7 @@ def main(args):
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
|
||||
log_validation(
|
||||
image_logs = log_validation(
|
||||
vae,
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
@@ -1040,6 +1091,12 @@ def main(args):
|
||||
controlnet.save_pretrained(args.output_dir)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(
|
||||
repo_id,
|
||||
image_logs=image_logs,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
repo_folder=args.output_dir,
|
||||
)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
|
||||
Reference in New Issue
Block a user