From fbcc383340bfd6376ba91cf941d43d95596927fe Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Thu, 27 Oct 2022 15:16:59 +0200
Subject: [PATCH] Deprecate `init_git_repo`, refactor `train_unconditional.py`
 (#1022)

Deprecate `init_git_repo` and `push_to_hub`, refactor `train_unconditional.py`
---
 .../train_unconditional.py | 228 +++++++++++++-----
 src/diffusers/hub_utils.py |  13 +-
 2 files changed, 186 insertions(+), 55 deletions(-)

diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index f9f8c85bd1..6a1495b34c 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -1,6 +1,8 @@
 import argparse
 import math
 import os
+from pathlib import Path
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -9,9 +11,9 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.hub_utils import init_git_repo
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
+from huggingface_hub import HfFolder, Repository, whoami
 from torchvision.transforms import (
     CenterCrop,
     Compose,
@@ -27,6 +29,160 @@ from tqdm.auto import tqdm
 
 logger = get_logger(__name__)
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that HF Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
+    parser.add_argument(
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help=(
+            "A folder containing the training data. Folder contents must follow the structure described in"
+            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
+            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="ddpm-model-64",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument("--overwrite_output_dir", action="store_true")
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=64,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+    )
+    parser.add_argument(
+        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+    )
+    parser.add_argument("--num_epochs", type=int, default=100)
+    parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
+    parser.add_argument(
+        "--save_model_epochs", type=int, default=10, help="How often to save the model during training."
+    )
+    parser.add_argument(
+        "--gradient_accumulation_steps",
+        type=int,
+        default=1,
+        help="Number of update steps to accumulate before performing a backward/update pass.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-4,
+        help="Initial learning rate (after the potential warmup period) to use.",
+    )
+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default="cosine",
+        help=(
+            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+            ' "constant", "constant_with_warmup"]'
+        ),
+    )
+    parser.add_argument(
+        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument("--adam_beta1", type=float, default=0.95, help="The beta1 parameter for the Adam optimizer.")
+    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
+    parser.add_argument(
+        "--adam_weight_decay", type=float, default=1e-6, help="Weight decay magnitude for the Adam optimizer."
+    )
+    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer.")
+    parser.add_argument(
+        "--use_ema",
+        action="store_true",
+        default=True,
+        help="Whether to use Exponential Moving Average for the final model weights.",
+    )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
+    parser.add_argument("--ema_power", type=float, default=3 / 4, help="The power value for the EMA decay.")
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999, help="The maximum decay magnitude for EMA.")
+    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+    parser.add_argument(
+        "--hub_model_id",
+        type=str,
+        default=None,
+        help="The name of the repository to keep in sync with the local `output_dir`.",
+    )
+    parser.add_argument(
+        "--hub_private_repo", action="store_true", help="Whether or not to create a private repository."
+    )
+    parser.add_argument(
+        "--logging_dir",
+        type=str,
+        default="logs",
+        help=(
+            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+        ),
+    )
+    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+    parser.add_argument(
+        "--mixed_precision",
+        type=str,
+        default="no",
+        choices=["no", "fp16", "bf16"],
+        help=(
+            "Whether to use mixed precision. Choose"
+            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
+            " and an Nvidia Ampere GPU."
+        ),
+    )
+
+    args = parser.parse_args()
+    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+    if env_local_rank != -1 and env_local_rank != args.local_rank:
+        args.local_rank = env_local_rank
+
+    if args.dataset_name is None and args.train_data_dir is None:
+        raise ValueError("You must specify either a dataset name from the hub or a train data directory.")
+
+    return args
+
+
+def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
+    if token is None:
+        token = HfFolder.get_token()
+    if organization is None:
+        username = whoami(token)["name"]
+        return f"{username}/{model_id}"
+    else:
+        return f"{organization}/{model_id}"
+
+
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
     accelerator = Accelerator(
@@ -110,8 +266,22 @@ def main(args):
         ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
-    if args.push_to_hub:
-        repo = init_git_repo(args, at_init=True)
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.push_to_hub:
+            if args.hub_model_id is None:
+                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+            else:
+                repo_name = args.hub_model_id
+            repo = Repository(args.output_dir, clone_from=repo_name)
+
+            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+                if "step_*" not in gitignore:
+                    gitignore.write("step_*\n")
+                if "epoch_*" not in gitignore:
+                    gitignore.write("epoch_*\n")
+        elif args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
 
     if accelerator.is_main_process:
         run = os.path.split(__file__)[-1].split(".")[0]
@@ -193,55 +363,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Simple example of a training script.")
-    parser.add_argument("--local_rank", type=int, default=-1)
-    parser.add_argument("--dataset_name", type=str, default=None)
-    parser.add_argument("--dataset_config_name", type=str, default=None)
-    parser.add_argument("--train_data_dir", type=str, default=None, help="A folder containing the training data.")
-    parser.add_argument("--output_dir", type=str, default="ddpm-model-64")
-    parser.add_argument("--overwrite_output_dir", action="store_true")
-    parser.add_argument("--cache_dir", type=str, default=None)
-    parser.add_argument("--resolution", type=int, default=64)
-    parser.add_argument("--train_batch_size", type=int, default=16)
-    parser.add_argument("--eval_batch_size", type=int, default=16)
-    parser.add_argument("--num_epochs", type=int, default=100)
-    parser.add_argument("--save_images_epochs", type=int, default=10)
-    parser.add_argument("--save_model_epochs", type=int, default=10)
-    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
-    parser.add_argument("--learning_rate", type=float, default=1e-4)
-    parser.add_argument("--lr_scheduler", type=str, default="cosine")
-    parser.add_argument("--lr_warmup_steps", type=int, default=500)
-    parser.add_argument("--adam_beta1", type=float, default=0.95)
-    parser.add_argument("--adam_beta2", type=float, default=0.999)
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-6)
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08)
-    parser.add_argument("--use_ema", action="store_true", default=True)
-    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
-    parser.add_argument("--ema_power", type=float, default=3 / 4)
-    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
-
parser.add_argument("--push_to_hub", action="store_true") - parser.add_argument("--hub_token", type=str, default=None) - parser.add_argument("--hub_model_id", type=str, default=None) - parser.add_argument("--hub_private_repo", action="store_true") - parser.add_argument("--logging_dir", type=str, default="logs") - parser.add_argument( - "--mixed_precision", - type=str, - default="no", - choices=["no", "fp16", "bf16"], - help=( - "Whether to use mixed precision. Choose" - "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." - "and an Nvidia Ampere GPU." - ), - ) - - args = parser.parse_args() - env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) - if env_local_rank != -1 and env_local_rank != args.local_rank: - args.local_rank = env_local_rank - - if args.dataset_name is None and args.train_data_dir is None: - raise ValueError("You must specify either a dataset name from the hub or a train data directory.") - + args = parse_args() main(args) diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index c07329e36f..1f8cc0db0f 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -22,7 +22,7 @@ from typing import Optional from huggingface_hub import HfFolder, Repository, whoami from .pipeline_utils import DiffusionPipeline -from .utils import is_modelcards_available, logging +from .utils import deprecate, is_modelcards_available, logging if is_modelcards_available(): @@ -53,6 +53,12 @@ def init_git_repo(args, at_init: bool = False): Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. """ + deprecation_message = ( + "Please use `huggingface_hub.Repository`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." + ) + deprecate("init_git_repo()", "0.10.0", deprecation_message) + if hasattr(args, "local_rank") and args.local_rank not in [-1, 0]: return hub_token = args.hub_token if hasattr(args, "hub_token") else None @@ -114,6 +120,11 @@ def push_to_hub( The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the commit and an object to track the progress of the commit if `blocking=True` """ + deprecation_message = ( + "Please use `huggingface_hub.Repository` and `Repository.push_to_hub()`. " + "See `examples/unconditional_image_generation/train_unconditional.py` for an example." + ) + deprecate("push_to_hub()", "0.10.0", deprecation_message) if not hasattr(args, "hub_model_id") or args.hub_model_id is None: model_name = Path(args.output_dir).name