1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Patrick von Platen
2022-06-27 15:34:47 +00:00
12 changed files with 351 additions and 136 deletions

View File

@@ -0,0 +1,201 @@
import argparse
import os
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMScheduler, Glide, GlideUNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
logger = logging.get_logger(__name__)
def main(args):
accelerator = Accelerator(mixed_precision=args.mixed_precision)
pipeline = Glide.from_pretrained("fusing/glide-base")
model = pipeline.text_unet
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
augmentations = Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
]
)
dataset = load_dataset(args.dataset, split="train")
text_encoder = pipeline.text_encoder.eval()
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
text_inputs = text_inputs.input_ids.to(accelerator.device)
with torch.no_grad():
text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state
return {"images": images, "text_embeddings": text_embeddings}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=args.warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, text_encoder, optimizer, train_dataloader, lr_scheduler
)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
# Train!
is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size() if is_distributed else 1
total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size
max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataloader.dataset)}")
logger.info(f" Num Epochs = {args.num_epochs}")
logger.info(f" Instantaneous batch size per device = {args.batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["images"]
batch_size, n_channels, height, width = clean_images.shape
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
timesteps = torch.randint(
0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device
).long()
# add noise onto the clean images according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
model_output = model(noisy_images, timesteps, batch["text_embeddings"])
model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
# predict the noise residual
loss = F.mse_loss(model_output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
optimizer.step()
else:
model_output = model(noisy_images, timesteps, batch["text_embeddings"])
model_output, model_var_values = torch.split(model_output, n_channels, dim=1)
# Learn the variance using the variational bound, but don't let
# it affect our mean prediction.
frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1)
# predict the noise residual
loss = F.mse_loss(model_output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
accelerator.wait_for_everyone()
# Generate a sample image for visual inspection
if accelerator.is_main_process:
model.eval()
with torch.no_grad():
pipeline.unet = accelerator.unwrap_model(model)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50)
# process image to PIL
image_processed = image.squeeze(0)
image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
image_pil = PIL.Image.fromarray(image_processed)
# save image
test_dir = os.path.join(args.output_dir, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch:04d}.png")
# save the model
if args.push_to_hub:
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
accelerator.wait_for_everyone()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
parser.add_argument("--output_dir", type=str, default="glide-text2image")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)
parser.add_argument("--hub_private_repo", action="store_true")
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU."
),
)
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
main(args)

View File

@@ -4,19 +4,19 @@ import os
import torch
import torch.nn.functional as F
import bitsandbytes as bnb
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetLDMModel
from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.modeling_utils import unwrap_model
from diffusers.optimization import get_scheduler
from diffusers.utils import logging
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Lambda,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
@@ -30,6 +30,8 @@ logger = logging.get_logger(__name__)
def main(args):
accelerator = Accelerator(mixed_precision=args.mixed_precision)
pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large")
pipeline.unet = None # this model will be trained from scratch now
model = UNetLDMModel(
attention_resolutions=[4, 2, 1],
channel_mult=[1, 2, 4, 4],
@@ -37,7 +39,7 @@ def main(args):
conv_resample=True,
dims=2,
dropout=0,
image_size=32,
image_size=8,
in_channels=4,
model_channels=320,
num_heads=8,
@@ -51,7 +53,7 @@ def main(args):
legacy=False,
)
noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt")
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr)
augmentations = Compose(
[
@@ -59,14 +61,22 @@ def main(args):
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
Normalize([0.5], [0.5]),
]
)
dataset = load_dataset(args.dataset, split="train")
text_encoder = pipeline.bert.eval()
vqvae = pipeline.vqvae.eval()
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state
images = 1 / 0.18215 * torch.stack(images, dim=0)
latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode()
return {"images": images, "text_embeddings": text_embeddings, "latents": latents}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
@@ -78,9 +88,11 @@ def main(args):
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler
)
text_encoder = text_encoder.cpu()
vqvae = vqvae.cpu()
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
@@ -98,29 +110,31 @@ def main(args):
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {max_steps}")
global_step = 0
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}")
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noise_samples = torch.randn(clean_images.shape).to(clean_images.device)
bsz = clean_images.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long()
clean_latents = batch["latents"]
noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device)
bsz = clean_latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long()
# add noise onto the clean images according to the noise magnitude at each timestep
# add noise onto the clean latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps)
noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps)
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
output = model(noisy_images, timesteps)
output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
accelerator.backward(loss)
optimizer.step()
else:
output = model(noisy_images, timesteps)
output = model(noisy_latents, timesteps, context=batch["text_embeddings"])
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
loss = loss / args.gradient_accumulation_steps
@@ -131,24 +145,25 @@ def main(args):
optimizer.zero_grad()
pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
global_step += 1
optimizer.step()
if is_distributed:
torch.distributed.barrier()
accelerator.wait_for_everyone()
# Generate a sample image for visual inspection
if args.local_rank in [-1, 0]:
if accelerator.is_main_process:
model.eval()
with torch.no_grad():
pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler)
pipeline.unet = accelerator.unwrap_model(model)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator)
image = pipeline(
["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50
)
# process image to PIL
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed * 255.0
image_processed = image_processed.type(torch.uint8).numpy()
image_pil = PIL.Image.fromarray(image_processed[0])
@@ -162,20 +177,19 @@ def main(args):
push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False)
else:
pipeline.save_pretrained(args.output_dir)
if is_distributed:
torch.distributed.barrier()
accelerator.wait_for_everyone()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--output_dir", type=str, default="ddpm-model")
parser.add_argument("--dataset", type=str, default="fusing/dog_captions")
parser.add_argument("--output_dir", type=str, default="ldm-text2image")
parser.add_argument("--overwrite_output_dir", action="store_true")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--resolution", type=int, default=128)
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--push_to_hub", action="store_true")

View File

@@ -7,7 +7,7 @@ import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
from diffusers.hub_utils import init_git_repo, push_to_hub
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
model, optimizer, train_dataloader, lr_scheduler
)
ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
if args.push_to_hub:
repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
# Generate a sample image for visual inspection
if accelerator.is_main_process:
with torch.no_grad():
pipeline = DDPM(
pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
)
@@ -172,6 +172,9 @@ if __name__ == "__main__":
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
parser.add_argument("--ema_power", type=float, default=3/4)
parser.add_argument("--ema_max_decay", type=float, default=0.999)
parser.add_argument("--push_to_hub", action="store_true")
parser.add_argument("--hub_token", type=str, default=None)
parser.add_argument("--hub_model_id", type=str, default=None)

View File

@@ -64,7 +64,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, use_conv_transpose=False, dims=2, out_channels=None):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
self.use_conv_transpose = use_conv_transpose
if use_conv_transpose:
self.conv = conv_transpose_nd(dims, channels, out_channels, 4, 2, 1)
self.conv = conv_transpose_nd(dims, channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
@@ -207,6 +207,7 @@ class GradTTSUpsample(torch.nn.Module):
return self.conv(x)
# TODO (patil-suraj): needs test
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()

View File

@@ -31,6 +31,7 @@ from tqdm import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
def nonlinearity(x):
@@ -42,20 +43,6 @@ def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
@@ -259,7 +246,7 @@ class UNetModel(ModelMixin, ConfigMixin):
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
up.upsample = Upsample(block_in, use_conv=resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order

View File

@@ -8,6 +8,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
def convert_module_to_f16(l):
@@ -125,36 +126,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@@ -231,8 +202,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
@@ -567,7 +538,7 @@ class GlideUNetModel(ModelMixin, ConfigMixin):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))

View File

@@ -3,6 +3,7 @@ import torch
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
class Mish(torch.nn.Module):
@@ -10,15 +11,6 @@ class Mish(torch.nn.Module):
return x * torch.tanh(torch.nn.functional.softplus(x))
class Upsample(torch.nn.Module):
def __init__(self, dim):
super(Upsample, self).__init__()
self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Downsample(torch.nn.Module):
def __init__(self, dim):
super(Downsample, self).__init__()
@@ -166,7 +158,7 @@ class UNetGradTTSModel(ModelMixin, ConfigMixin):
ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
Residual(Rezero(LinearAttention(dim_in))),
Upsample(dim_in),
Upsample(dim_in, use_conv_transpose=True),
]
)
)

View File

@@ -10,6 +10,7 @@ import torch.nn.functional as F
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .embeddings import get_timestep_embedding
from .resnet import Upsample
# try:
@@ -403,35 +404,6 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest")
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
@@ -506,8 +478,8 @@ class ResBlock(TimestepBlock):
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
self.h_upd = Upsample(channels, use_conv=False, dims=dims)
self.x_upd = Upsample(channels, use_conv=False, dims=dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
@@ -974,7 +946,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin):
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
else Upsample(ch, use_conv=conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))

View File

@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample
def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor):
if timesteps.dim() != 1:
raise ValueError("`timesteps` must be a 1D tensor")
device = original_samples.device
batch_size = original_samples.shape[0]
timesteps = timesteps.reshape(batch_size, 1, 1, 1)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise
sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples)
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def __len__(self):

View File

@@ -14,6 +14,8 @@
import numpy as np
import torch
from typing import Union
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@@ -50,3 +52,29 @@ class SchedulerMixin:
return torch.log(tensor)
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def match_shape(
self,
values: Union[np.ndarray, torch.Tensor],
broadcast_array: Union[np.ndarray, torch.Tensor]
):
"""
Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims.
Args:
timesteps: an array or tensor of values to extract.
broadcast_array: an array with a larger shape of K dimensions with the batch
dimension equal to the length of timesteps.
Returns:
a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
tensor_format = getattr(self, "tensor_format", "pt")
values = values.flatten()
while len(values.shape) < len(broadcast_array.shape):
values = values[..., None]
if tensor_format == "pt":
values = values.to(broadcast_array.device)
return values

View File

@@ -22,6 +22,7 @@ import numpy as np
import torch
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Upsample
from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
torch.tensor([-0.9801, -0.9464, -0.9349, -0.3952, 0.8887, -0.9709, 0.5299, -0.2853, -0.9927]),
1e-3,
)
class UpsampleBlockTests(unittest.TestCase):
def test_upsample_default(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_conv(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.7145, 1.3773, 0.3492, 0.8448, 1.0839, -0.3341, 0.5956, 0.1250, -0.4841])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_conv_out_dim(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=True, out_channels=64)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 64, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([0.2703, 0.1656, -0.2538, -0.0553, -0.2984, 0.1044, 0.1155, 0.2579, 0.7755])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_upsample_with_transpose(self):
torch.manual_seed(0)
sample = torch.randn(1, 32, 32, 32)
upsample = Upsample(channels=32, use_conv=False, use_conv_transpose=True)
with torch.no_grad():
upsampled = upsample(sample)
assert upsampled.shape == (1, 32, 64, 64)
output_slice = upsampled[0, -1, -3:, -3:]
expected_slice = torch.tensor([-0.3028, -0.1582, 0.0071, 0.0350, -0.4799, -0.1139, 0.1056, -0.1153, -0.1046])
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)

View File

@@ -21,7 +21,7 @@ import unittest
import numpy as np
import torch
from diffusers import (
from diffusers import ( # GradTTSPipeline,
BDDMPipeline,
DDIMPipeline,
DDIMScheduler,
@@ -30,7 +30,6 @@ from diffusers import (
GlidePipeline,
GlideSuperResUNetModel,
GlideTextToImageUNetModel,
GradTTSPipeline,
GradTTSScheduler,
LatentDiffusionPipeline,
NCSNpp,