From 53605ed00add124dc4d9d000cebb230555a788cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Tolga=20Cang=C3=B6z?= <46008593+standardAI@users.noreply.github.com> Date: Fri, 23 Feb 2024 18:27:37 +0300 Subject: [PATCH] [`Refactor`] `save_model_card` function in `text_to_image` examples (#7051) * Refactor save_model_card function to handle images and repo_folder parameters * Discard changes to examples/text_to_image/train_text_to_image.py * Discard changes to examples/text_to_image/train_text_to_image_lora_sdxl.py * Update train_text_to_image_lora.py * Update train_text_to_image_sdxl.py --- examples/text_to_image/train_text_to_image_lora.py | 13 +++++++++---- examples/text_to_image/train_text_to_image_sdxl.py | 7 ++++--- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 39590fa866..3483850074 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -58,12 +58,17 @@ logger = get_logger(__name__, log_level="INFO") def save_model_card( - repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None + repo_id: str, + images: list = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, ): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" model_description = f""" # LoRA text2image fine-tuning - {repo_id} diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 04f8c3dba4..2d77e9c8bf 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -74,9 +74,10 @@ def save_model_card( vae_path: str = None, ): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" model_description = f""" # Text-to-image finetuning - {repo_id}