mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Framework-agnostic timestep broadcasting
This commit is contained in:
@@ -7,7 +7,7 @@ import torch.nn.functional as F
|
||||
import PIL.Image
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||
from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel
|
||||
@@ -71,7 +71,7 @@ def main(args):
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
|
||||
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo = init_git_repo(args, at_init=True)
|
||||
@@ -133,7 +133,7 @@ def main(args):
|
||||
# Generate a sample image for visual inspection
|
||||
if accelerator.is_main_process:
|
||||
with torch.no_grad():
|
||||
pipeline = DDPM(
|
||||
pipeline = DDPMPipeline(
|
||||
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
|
||||
)
|
||||
|
||||
@@ -172,6 +172,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--warmup_steps", type=int, default=500)
|
||||
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
|
||||
parser.add_argument("--ema_power", type=float, default=3/4)
|
||||
parser.add_argument("--ema_max_decay", type=float, default=0.999)
|
||||
parser.add_argument("--push_to_hub", action="store_true")
|
||||
parser.add_argument("--hub_token", type=str, default=None)
|
||||
parser.add_argument("--hub_model_id", type=str, default=None)
|
||||
|
||||
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||
return pred_prev_sample
|
||||
|
||||
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
|
||||
if timesteps.dim() != 1:
|
||||
raise ValueError("`timesteps` must be a 1D tensor")
|
||||
|
||||
device = original_samples.device
|
||||
batch_size = original_samples.shape[0]
|
||||
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
|
||||
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
@@ -50,3 +52,29 @@ class SchedulerMixin:
|
||||
return torch.log(tensor)
|
||||
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def match_shape(
|
||||
self,
|
||||
values: Union[np.ndarray, torch.Tensor],
|
||||
broadcast_array: Union[np.ndarray, torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
|
||||
|
||||
Args:
|
||||
timesteps: an array or tensor of values to extract.
|
||||
broadcast_array: an array with a larger shape of K dimensions with the batch
|
||||
dimension equal to the length of timesteps.
|
||||
Returns:
|
||||
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
||||
"""
|
||||
|
||||
tensor_format = getattr(self, "tensor_format", "pt")
|
||||
values = values.flatten()
|
||||
|
||||
while len(values.shape) < len(broadcast_array.shape):
|
||||
values = values[..., None]
|
||||
if tensor_format == "pt":
|
||||
values = values.to(broadcast_array.device)
|
||||
|
||||
return values
|
||||
|
||||
Reference in New Issue
Block a user