update code to reflect latest changes as of May 30th (#3616)
* update code to reflect latest changes as of May 30th
* update text to image example
* reflect changes to textual inversion
* make style
* fix typo
* Revert unnecessary readme changes

---------

Co-authored-by: root <root@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Prathik Rao <prathikrao@microsoft.com@orttrainingdev8.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
@@ -20,6 +20,7 @@ import os
 import random
 from pathlib import Path
 
+import accelerate
 import datasets
 import numpy as np
 import torch
@@ -28,30 +29,96 @@ import torch.utils.checkpoint
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
 from onnxruntime.training.ortmodule import ORTModule
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, deprecate, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.13.0.dev0")
+check_min_version("0.17.0.dev0")
 
 logger = get_logger(__name__, log_level="INFO")
 
+DATASET_NAME_MAPPING = {
+    "lambdalabs/pokemon-blip-captions": ("image", "text"),
+}
+
+
+def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
+    logger.info("Running validation... ")
+
+    pipeline = StableDiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        vae=accelerator.unwrap_model(vae),
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=accelerator.unwrap_model(unet),
+        safety_checker=None,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if args.enable_xformers_memory_efficient_attention:
+        pipeline.enable_xformers_memory_efficient_attention()
+
+    if args.seed is None:
+        generator = None
+    else:
+        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
+
+    images = []
+    for i in range(len(args.validation_prompts)):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
+
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        elif tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
+                        for i, image in enumerate(images)
+                    ]
+                }
+            )
+        else:
+            logger.warn(f"image logging not implemented for {tracker.name}")
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--input_pertubation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -110,6 +177,13 @@ def parse_args():
             "value if set."
         ),
     )
+    parser.add_argument(
+        "--validation_prompts",
+        type=str,
+        default=None,
+        nargs="+",
+        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
+    )
     parser.add_argument(
         "--output_dir",
         type=str,
@@ -191,6 +265,13 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -295,6 +376,22 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+    parser.add_argument(
+        "--validation_epochs",
+        type=int,
+        default=5,
+        help="Run validation every X epochs.",
+    )
+    parser.add_argument(
+        "--tracker_project_name",
+        type=str,
+        default="text2image-fine-tune",
+        help=(
+            "The `project_name` argument passed to Accelerator.init_trackers for"
+            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+        ),
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -312,13 +409,18 @@ def parse_args():
     return args
 
 
-dataset_name_mapping = {
-    "lambdalabs/pokemon-blip-captions": ("image", "text"),
-}
-
 
 def main():
     args = parse_args()
 
+    if args.non_ema_revision is not None:
+        deprecate(
+            "non_ema_revision!=None",
+            "0.15.0",
+            message=(
+                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
+                " use `--variant=non_ema` instead."
+            ),
+        )
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
 
     accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
@@ -366,10 +468,34 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        returns either a context list that includes one that will disable zero.Init or an empty context list
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+    # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
 
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
@@ -383,17 +509,81 @@ def main():
         ema_unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
         )
-        ema_unet = EMAModel(ema_unet.parameters())
+        ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
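+        # Annotation (not in the commit): model_cls and model_config are what let the EMA
+        # weights round-trip through EMAModel.save_pretrained / EMAModel.from_pretrained in
+        # the save/load hooks registered further down.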
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             import xformers
 
             xformers_version = version.parse(xformers.__version__)
             if xformers_version == version.parse("0.0.16"):
                 logger.warn(
                     "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                 )
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
 
+    def compute_snr(timesteps):
+        """
+        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
+        """
+        alphas_cumprod = noise_scheduler.alphas_cumprod
+        sqrt_alphas_cumprod = alphas_cumprod**0.5
+        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+        # Expand the tensors.
+        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
+        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
+        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
+
+        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
+            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
+        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
+
+        # Compute SNR.
+        snr = (alpha / sigma) ** 2
+        return snr
+
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
+                ema_unet.load_state_dict(load_model.state_dict())
+                ema_unet.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         vae.enable_gradient_checkpointing()
 
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -426,6 +616,8 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    optimizer = ORT_FP16_Optimizer(optimizer)
+
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 
@@ -455,7 +647,7 @@ def main():
     column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
-    dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
+    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
     if args.image_column is None:
         image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
     else:
@@ -549,10 +741,10 @@ def main():
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    unet = ORTModule(unet)
-
     if args.use_ema:
         accelerator.register_for_checkpointing(ema_unet)
         ema_unet.to(accelerator.device)
 
+    unet = ORTModule(unet)
+
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -565,8 +757,6 @@ def main():
     # Move text_encode and vae to gpu and cast to weight_dtype
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
-    if args.use_ema:
-        ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -578,7 +768,9 @@ def main():
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
-        accelerator.init_trackers("text2image-fine-tune", config=vars(args))
+        tracker_config = dict(vars(args))
+        tracker_config.pop("validation_prompts")
+        accelerator.init_trackers(args.tracker_project_name, tracker_config)
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -639,6 +831,13 @@ def main():
 
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(latents)
+                if args.noise_offset:
+                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
+                    noise += args.noise_offset * torch.randn(
+                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
+                    )
+                if args.input_pertubation:
+                    new_noise = noise + args.input_pertubation * torch.randn_like(noise)
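+                    # Annotation (not in the commit): the perturbed new_noise is only used as
+                    # input to the forward diffusion step below; the unperturbed noise stays
+                    # the regression target (https://arxiv.org/abs/2301.11706).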
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -646,7 +845,10 @@ def main():
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                if args.input_pertubation:
+                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
+                else:
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
@@ -660,8 +862,24 @@ def main():
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
                 # Predict the noise residual and compute loss
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = compute_snr(timesteps)
+                    mse_loss_weights = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+                    # We first calculate the original loss. Then we mean over the non-batch dimensions and
+                    # rebalance the sample-wise losses with their respective loss weights.
+                    # Finally, we take the mean of the rebalanced loss.
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
 
                 # Gather the losses across all processes for logging (if we use distributed training).
                 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -696,6 +914,26 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
+        if accelerator.is_main_process:
+            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
+                if args.use_ema:
+                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
+                    ema_unet.store(unet.parameters())
+                    ema_unet.copy_to(unet.parameters())
+                log_validation(
+                    vae,
+                    text_encoder,
+                    tokenizer,
+                    unet,
+                    args,
+                    accelerator,
+                    weight_dtype,
+                    global_step,
+                )
+                if args.use_ema:
+                    # Switch back to the original UNet parameters.
+                    ema_unet.restore(unet.parameters())
+
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
@@ -53,7 +53,19 @@ If you have already cloned the repo, then you won't need to go through these steps.
 
 <br>
 
-Now let's get our dataset.Download 3-4 images from [here](https://drive.google.com/drive/folders/1fmJMs25nxS_rSNqS5hTcRdLem_YQXbq5) and save them in a directory. This will be our training data.
+Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
+
+Let's first download it locally:
+
+```py
+from huggingface_hub import snapshot_download
+
+local_dir = "./cat"
+snapshot_download("diffusers/cat_toy_example", local_dir=local_dir, repo_type="dataset", ignore_patterns=".gitattributes")
+```
+
+This will be our training data.
 Now we can launch the training using
 
 ## Use ONNXRuntime to accelerate training
 In order to leverage onnxruntime to accelerate training, please use textual_inversion.py
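
A launch command is not shown in this diff; as a rough sketch, it mirrors the vanilla textual inversion example (the hyperparameter values below are illustrative, not taken from this commit):

```bash
accelerate launch textual_inversion.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --train_data_dir="./cat" \
  --placeholder_token="<cat-toy>" --initializer_token="toy" \
  --resolution=512 --train_batch_size=1 \
  --learning_rate=5e-4 --max_train_steps=3000 \
  --output_dir="textual_inversion_cat"
```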
@@ -18,9 +18,9 @@ import logging
 import math
 import os
 import random
+import warnings
 from pathlib import Path
 
 import datasets
 import numpy as np
 import PIL
 import torch
@@ -31,6 +31,7 @@ from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
|
||||
from onnxruntime.training.ortmodule import ORTModule
|
||||
|
||||
# TODO: remove and import from diffusers.utils when the new version of diffusers is released
|
||||
@@ -55,6 +56,9 @@ from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
+if is_wandb_available():
+    import wandb
+
 if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
     PIL_INTERPOLATION = {
         "linear": PIL.Image.Resampling.BILINEAR,
@@ -75,14 +79,92 @@ else:
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.13.0.dev0")
+check_min_version("0.17.0.dev0")
 
 logger = get_logger(__name__)
 
 
-def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+def save_model_card(repo_id: str, images=None, base_model=str, repo_folder=None):
+    img_str = ""
+    for i, image in enumerate(images):
+        image.save(os.path.join(repo_folder, f"image_{i}.png"))
+        img_str += f"![img_{i}](./image_{i}.png)\n"
+
+    yaml = f"""
+---
+license: creativeml-openrail-m
+base_model: {base_model}
+tags:
+- stable-diffusion
+- stable-diffusion-diffusers
+- text-to-image
+- diffusers
+- textual_inversion
+inference: true
+---
+"""
+    model_card = f"""
+# Textual inversion text2image fine-tuning - {repo_id}
+These are textual inversion adaption weights for {base_model}. You can find some example images in the following. \n
+{img_str}
+"""
+    with open(os.path.join(repo_folder, "README.md"), "w") as f:
+        f.write(yaml + model_card)
+
+
+def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # create pipeline (note: unet and vae are loaded again in float32)
+    pipeline = DiffusionPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        text_encoder=accelerator.unwrap_model(text_encoder),
+        tokenizer=tokenizer,
+        unet=unet,
+        vae=vae,
+        safety_checker=None,
+        revision=args.revision,
+        torch_dtype=weight_dtype,
+    )
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
+    images = []
+    for _ in range(args.num_validation_images):
+        with torch.autocast("cuda"):
+            image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+        images.append(image)
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    "validation": [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+    return images
+
+
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
     logger.info("Saving embeddings")
-    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
+    learned_embeds = (
+        accelerator.unwrap_model(text_encoder)
+        .get_input_embeddings()
+        .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
+    )
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
     torch.save(learned_embeds_dict, save_path)
 
@@ -96,10 +178,15 @@ def parse_args():
         help="Save learned_embeds.bin every X updates steps.",
     )
     parser.add_argument(
-        "--only_save_embeds",
+        "--save_as_full_pipeline",
         action="store_true",
-        default=False,
-        help="Save only the embeddings for the new concept.",
+        help="Save the complete stable diffusion pipeline.",
     )
+    parser.add_argument(
+        "--num_vectors",
+        type=int,
+        default=1,
+        help="How many textual inversion vectors shall be used to learn the concept.",
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
@@ -269,12 +356,22 @@ def parse_args():
         default=4,
         help="Number of images that should be generated during validation with `validation_prompt`.",
     )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=100,
+        help=(
+            "Run validation every X steps. Validation consists of running the prompt"
+            " `args.validation_prompt` multiple times: `args.num_validation_images`"
+            " and logging the images."
+        ),
+    )
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
         help=(
-            "Run validation every X epochs. Validation consists of running the prompt"
+            "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
             " `args.validation_prompt` multiple times: `args.num_validation_images`"
             " and logging the images."
         ),
@@ -479,7 +576,6 @@ def main():
     if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-       import wandb
 
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
@@ -489,11 +585,9 @@ def main():
     )
     logger.info(accelerator.state, main_process_only=False)
     if accelerator.is_local_main_process:
-        datasets.utils.logging.set_verbosity_warning()
         transformers.utils.logging.set_verbosity_warning()
         diffusers.utils.logging.set_verbosity_info()
     else:
-        datasets.utils.logging.set_verbosity_error()
         transformers.utils.logging.set_verbosity_error()
         diffusers.utils.logging.set_verbosity_error()
 
@@ -528,8 +622,19 @@ def main():
     )
 
     # Add the placeholder token in tokenizer
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    if num_added_tokens == 0:
+    placeholder_tokens = [args.placeholder_token]
+
+    if args.num_vectors < 1:
+        raise ValueError(f"--num_vectors has to be larger or equal to 1, but is {args.num_vectors}")
+
+    # add dummy tokens for multi-vector
+    additional_tokens = []
+    for i in range(1, args.num_vectors):
+        additional_tokens.append(f"{args.placeholder_token}_{i}")
+    placeholder_tokens += additional_tokens
+
+    num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
+    if num_added_tokens != args.num_vectors:
         raise ValueError(
             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
             " `placeholder_token` that is not already in the tokenizer."
@@ -542,14 +647,16 @@ def main():
         raise ValueError("The initializer token must be a single token.")
 
     initializer_token_id = token_ids[0]
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-    token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+    with torch.no_grad():
+        for token_id in placeholder_token_ids:
+            token_embeds[token_id] = token_embeds[initializer_token_id].clone()
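+            # Annotation (not in the commit): each placeholder vector starts as a copy of
+            # the initializer token's embedding; .clone() keeps the rows independent so
+            # every vector can be optimized separately.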
 
     # Freeze vae and unet
     vae.requires_grad_(False)
@@ -568,6 +675,13 @@ def main():
 
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")
@@ -591,6 +705,8 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    optimizer = ORT_FP16_Optimizer(optimizer)
+
     # Dataset and DataLoaders creation:
     train_dataset = TextualInversionDataset(
         data_root=args.train_data_dir,
@@ -605,6 +721,15 @@ def main():
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
+    if args.validation_epochs is not None:
+        warnings.warn(
+            f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
+            " Deprecated validation_epochs in favor of `validation_steps`"
+            f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
+            FutureWarning,
+            stacklevel=2,
+        )
+        args.validation_steps = args.validation_epochs * len(train_dataset)
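+        # Annotation (not in the commit): this conversion treats one epoch as
+        # len(train_dataset) optimizer steps, which is exact only for
+        # train_batch_size=1 without gradient accumulation.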
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -626,6 +751,8 @@ def main():
     )
 
     text_encoder = ORTModule(text_encoder)
+    unet = ORTModule(unet)
+    vae = ORTModule(vae)
 
     # For mixed precision training we cast the unet and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
@@ -663,7 +790,6 @@ def main():
     logger.info(f"  Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
-
     # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
@@ -744,7 +870,9 @@ def main():
                 optimizer.zero_grad()
 
                 # Let's make sure we don't update any embedding weights besides the newly added token
-                index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
+                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+                index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
 
                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
                         index_no_updates
@@ -752,72 +880,38 @@ def main():
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                images = []
                 progress_bar.update(1)
                 global_step += 1
                 if global_step % args.save_steps == 0:
                     save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
-                    save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
 
-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        images = log_validation(
+                            text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch
+                        )
+
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
 
             if global_step >= args.max_train_steps:
                 break
 
-        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-            logger.info(
-                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                f" {args.validation_prompt}."
-            )
-            # create pipeline (note: unet and vae are loaded again in float32)
-            pipeline = DiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path,
-                text_encoder=accelerator.unwrap_model(text_encoder),
-                revision=args.revision,
-            )
-            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-            pipeline = pipeline.to(accelerator.device)
-            pipeline.set_progress_bar_config(disable=True)
-
-            # run inference
-            generator = (
-                None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
-            )
-            prompt = args.num_validation_images * [args.validation_prompt]
-            images = pipeline(prompt, num_inference_steps=25, generator=generator).images
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
-
-            del pipeline
-            torch.cuda.empty_cache()
-
-    # Create the pipeline using using the trained modules and save it.
+    # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        if args.push_to_hub and args.only_save_embeds:
+        if args.push_to_hub and not args.save_as_full_pipeline:
             logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
             save_full_model = True
         else:
-            save_full_model = not args.only_save_embeds
+            save_full_model = args.save_as_full_pipeline
         if save_full_model:
             pipeline = StableDiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
@@ -829,9 +923,15 @@ def main():
             pipeline.save_pretrained(args.output_dir)
         # Save the newly trained embeddings
         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
-        save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
 
         if args.push_to_hub:
+            save_model_card(
+                repo_id,
+                images=images,
+                base_model=args.pretrained_model_name_or_path,
+                repo_folder=args.output_dir,
+            )
             upload_folder(
                 repo_id=repo_id,
                 folder_path=args.output_dir,

@@ -34,7 +34,7 @@ In order to leverage onnxruntime to accelerate training, please use train_uncond
 The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxruntime:
 
 ```bash
-accelerate launch train_unconditional_ort.py \
+accelerate launch train_unconditional.py \
   --dataset_name="huggan/flowers-102-categories" \
   --resolution=64 --center_crop --random_flip \
   --output_dir="ddpm-ema-flowers-64" \

@@ -1,3 +1,4 @@
+accelerate>=0.16.0
 torchvision
 datasets
 tensorboard
@@ -6,6 +6,7 @@ import os
 from pathlib import Path
 from typing import Optional
 
+import accelerate
 import datasets
 import torch
 import torch.nn.functional as F
@@ -14,7 +15,9 @@ from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration
 from datasets import load_dataset
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
 from onnxruntime.training.ortmodule import ORTModule
+from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 
@@ -22,11 +25,12 @@ import diffusers
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
+from diffusers.utils.import_utils import is_xformers_available
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.13.0.dev0")
+check_min_version("0.17.0.dev0")
 
 logger = get_logger(__name__, log_level="INFO")
 
@@ -34,6 +38,7 @@ logger = get_logger(__name__, log_level="INFO")
|
||||
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
||||
"""
|
||||
Extract values from a 1-D numpy array for a batch of indices.
|
||||
|
||||
:param arr: the 1-D numpy array.
|
||||
:param timesteps: a tensor of indices into the array to extract.
|
||||
:param broadcast_shape: a larger shape of K dimensions with the batch
|
||||
@@ -66,6 +71,12 @@ def parse_args():
         default=None,
         help="The config of the Dataset, leave as None if there's only one config.",
     )
+    parser.add_argument(
+        "--model_config_name_or_path",
+        type=str,
+        default=None,
+        help="The config of the UNet model to train, leave as None to use standard DDPM configuration.",
+    )
     parser.add_argument(
         "--train_data_dir",
         type=str,
@@ -251,6 +262,9 @@ def parse_args():
             ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
         ),
     )
+    parser.add_argument(
+        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -295,6 +309,40 @@ def main(args):
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb
 
+    # `accelerate` 0.16.0 will have better support for customized saving
+    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
+        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+        def save_model_hook(models, weights, output_dir):
+            if args.use_ema:
+                ema_model.save_pretrained(os.path.join(output_dir, "unet_ema"))
+
+            for i, model in enumerate(models):
+                model.save_pretrained(os.path.join(output_dir, "unet"))
+
+                # make sure to pop weight so that corresponding model is not saved again
+                weights.pop()
+
+        def load_model_hook(models, input_dir):
+            if args.use_ema:
+                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DModel)
+                ema_model.load_state_dict(load_model.state_dict())
+                ema_model.to(accelerator.device)
+                del load_model
+
+            for i in range(len(models)):
+                # pop models so that they are not loaded again
+                model = models.pop()
+
+                # load diffusers style into model
+                load_model = UNet2DModel.from_pretrained(input_dir, subfolder="unet")
+                model.register_to_config(**load_model.config)
+
+                model.load_state_dict(load_model.state_dict())
+                del load_model
+
+        accelerator.register_save_state_pre_hook(save_model_hook)
+        accelerator.register_load_state_pre_hook(load_model_hook)
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -328,29 +376,33 @@ def main(args):
         os.makedirs(args.output_dir, exist_ok=True)
 
     # Initialize the model
-    model = UNet2DModel(
-        sample_size=args.resolution,
-        in_channels=3,
-        out_channels=3,
-        layers_per_block=2,
-        block_out_channels=(128, 128, 256, 256, 512, 512),
-        down_block_types=(
-            "DownBlock2D",
-            "DownBlock2D",
-            "DownBlock2D",
-            "DownBlock2D",
-            "AttnDownBlock2D",
-            "DownBlock2D",
-        ),
-        up_block_types=(
-            "UpBlock2D",
-            "AttnUpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-            "UpBlock2D",
-        ),
-    )
+    if args.model_config_name_or_path is None:
+        model = UNet2DModel(
+            sample_size=args.resolution,
+            in_channels=3,
+            out_channels=3,
+            layers_per_block=2,
+            block_out_channels=(128, 128, 256, 256, 512, 512),
+            down_block_types=(
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "DownBlock2D",
+                "AttnDownBlock2D",
+                "DownBlock2D",
+            ),
+            up_block_types=(
+                "UpBlock2D",
+                "AttnUpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+                "UpBlock2D",
+            ),
+        )
+    else:
+        config = UNet2DModel.load_config(args.model_config_name_or_path)
+        model = UNet2DModel.from_config(config)
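+        # Annotation (not in the commit): load_config + from_config builds the architecture
+        # described by --model_config_name_or_path (a local file or Hub id) with freshly
+        # initialized weights, rather than loading pretrained ones.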
 
     # Create EMA for the model.
     if args.use_ema:
@@ -360,8 +412,23 @@ def main(args):
         use_ema_warmup=True,
         inv_gamma=args.ema_inv_gamma,
         power=args.ema_power,
+        model_cls=UNet2DModel,
+        model_config=model.config,
     )
 
+    if args.enable_xformers_memory_efficient_attention:
+        if is_xformers_available():
+            import xformers
+
+            xformers_version = version.parse(xformers.__version__)
+            if xformers_version == version.parse("0.0.16"):
+                logger.warn(
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                )
+            model.enable_xformers_memory_efficient_attention()
+        else:
+            raise ValueError("xformers is not available. Make sure it is installed correctly")
+
     # Initialize the scheduler
     accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
     if accepts_prediction_type:
@@ -382,6 +449,8 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
+    optimizer = ORT_FP16_Optimizer(optimizer)
+
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 
@@ -434,10 +503,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
-    model = ORTModule(model)
-
     if args.use_ema:
-        accelerator.register_for_checkpointing(ema_model)
         ema_model.to(accelerator.device)
 
     # We need to initialize the trackers we use, and also store our configuration.
@@ -446,6 +512,8 @@ def main(args):
         run = os.path.split(__file__)[-1].split(".")[0]
         accelerator.init_trackers(run)
 
+    model = ORTModule(model)
+
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     max_train_steps = args.num_epochs * num_update_steps_per_epoch
@@ -552,7 +620,7 @@ def main(args):
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
             if args.use_ema:
-                logs["ema_decay"] = ema_model.decay
+                logs["ema_decay"] = ema_model.cur_decay_value
             progress_bar.set_postfix(**logs)
             accelerator.log(logs, step=global_step)
         progress_bar.close()
@@ -563,8 +631,11 @@ def main(args):
         if accelerator.is_main_process:
             if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
                 unet = accelerator.unwrap_model(model)
 
+                if args.use_ema:
+                    ema_model.store(unet.parameters())
+                    ema_model.copy_to(unet.parameters())
+
                 pipeline = DDPMPipeline(
                     unet=unet,
                     scheduler=noise_scheduler,
@@ -575,18 +646,24 @@ def main(args):
                 images = pipeline(
                     generator=generator,
                     batch_size=args.eval_batch_size,
-                    output_type="numpy",
                     num_inference_steps=args.ddpm_num_inference_steps,
+                    output_type="numpy",
                 ).images
 
+                if args.use_ema:
+                    ema_model.restore(unet.parameters())
+
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")
 
                 if args.logger == "tensorboard":
-                    accelerator.get_tracker("tensorboard").add_images(
-                        "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
-                    )
+                    if is_accelerate_version(">=", "0.17.0.dev0"):
+                        tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+                    else:
+                        tracker = accelerator.get_tracker("tensorboard")
+                    tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
                 elif args.logger == "wandb":
                     # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
                     accelerator.get_tracker("wandb").log(
                         {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
                         step=global_step,
@@ -594,7 +671,22 @@ def main(args):
 
             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
+                unet = accelerator.unwrap_model(model)
+
+                if args.use_ema:
+                    ema_model.store(unet.parameters())
+                    ema_model.copy_to(unet.parameters())
+
+                pipeline = DDPMPipeline(
+                    unet=unet,
+                    scheduler=noise_scheduler,
+                )
+
                 pipeline.save_pretrained(args.output_dir)
 
+                if args.use_ema:
+                    ema_model.restore(unet.parameters())
+
                 if args.push_to_hub:
                     repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)