From 777063e1bfda024e7dfc3a9ba2acb20552aec6bc Mon Sep 17 00:00:00 2001 From: Bhavay Malhotra <56443877+Bhavay-2001@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:39:51 +0530 Subject: [PATCH] Update textual_inversion.py (#6952) * Update textual_inversion.py * Apply suggestions from code review * Update textual_inversion.py * Update textual_inversion.py * Update textual_inversion.py * Update textual_inversion.py * Update examples/textual_inversion/textual_inversion.py Co-authored-by: Sayak Paul * Update textual_inversion.py * styling --------- Co-authored-by: Sayak Paul --- .../textual_inversion/textual_inversion.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 6f02aaeae4..02988dd139 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -53,6 +53,7 @@ from diffusers import ( ) from diffusers.optimization import get_scheduler from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_xformers_available @@ -84,32 +85,30 @@ check_min_version("0.27.0.dev0") logger = get_logger(__name__) -def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None): +def save_model_card(repo_id: str, images: list = None, base_model: str = None, repo_folder: str = None): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" - - yaml = f""" ---- -license: creativeml-openrail-m -base_model: {base_model} -tags: -- stable-diffusion -- stable-diffusion-diffusers -- text-to-image -- diffusers -- textual_inversion -inference: true ---- - """ - model_card = f""" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" + model_description = f""" # Textual inversion text2image fine-tuning - {repo_id} These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n {img_str} """ - with open(os.path.join(repo_folder, "README.md"), "w") as f: - f.write(yaml + model_card) + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "textual_inversion"] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):