1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Framework-agnostic timestep broadcasting

This commit is contained in:
anton-l
2022-06-27 17:11:01 +02:00
parent 0e13d3293c
commit 1cf7933ea2
3 changed files with 38 additions and 11 deletions

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler
)
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
# Generate a sample image for visual inspection
if accelerator.is_main_process:
with torch.no_grad():
pipeline = DDPM(
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
)
@@ -172,6 +172,9 @@ if __name__ == "__main__":
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3/4)
parser.add_argument("--ema_max_decay", type=float, default=0.999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)

View File

@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
if timesteps.dim() != 1:
raise ValueError("`timesteps` must be a 1D tensor")
device = original_samples.device
batch_size = original_samples.shape[0]
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):

View File

@@ -14,6 +14,8 @@
import numpy as np
import torch
from typing import Union
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -50,3 +52,29 @@ class SchedulerMixin:
return torch.log(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def match_shape(
self,
values: Union[np.ndarray, torch.Tensor],
broadcast_array: Union[np.ndarray, torch.Tensor]
):
"""
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args:
timesteps: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
Returns:
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
tensor_format = getattr(self, "tensor_format", "pt")
values = values.flatten()
while len(values.shape) < len(broadcast_array.shape):
values = values[..., None]
if tensor_format == "pt":
values = values.to(broadcast_array.device)
return values