From 4b89aeffe154197aae8e19d58fe1bb1ded813da2 Mon Sep 17 00:00:00 2001 From: Piyush Thakur <53268607+cosmo3769@users.noreply.github.com> Date: Tue, 13 Feb 2024 08:56:45 +0530 Subject: [PATCH] [Type annotations] fixed in save_model_card (#6948) fixed type annotations Co-authored-by: Sayak Paul --- examples/text_to_image/train_text_to_image.py | 4 ++-- examples/text_to_image/train_text_to_image_lora.py | 4 +++- examples/text_to_image/train_text_to_image_sdxl.py | 12 ++++++------ 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 6cc8db6fb2..6fb8b17944 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -67,8 +67,8 @@ DATASET_NAME_MAPPING = { def save_model_card( args, repo_id: str, - images=None, - repo_folder=None, + images: list = None, + repo_folder: str = None, ): img_str = "" if len(images) > 0: diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 73d0470522..47e67f695b 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -56,7 +56,9 @@ check_min_version("0.27.0.dev0") logger = get_logger(__name__, log_level="INFO") -def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): +def save_model_card( + repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None +): img_str = "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 8a5948350d..292e52bca0 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -66,12 +66,12 @@ DATASET_NAME_MAPPING = { def save_model_card( repo_id: str, - images=None, - validation_prompt=None, - base_model=str, - dataset_name=str, - repo_folder=None, - vae_path=None, + images: list = None, + validation_prompt: str = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, + vae_path: str = None, ): img_str = "" for i, image in enumerate(images):