From d726857f7e7f35aa8c1f3d031048ba6c7cb069f3 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 27 Jun 2022 15:09:33 +0000
Subject: [PATCH 1/6] remove einops from unet_ldm

---
 src/diffusers/models/unet_ldm.py | 51 +++++++++++++++++++++++---------
 1 file changed, 37 insertions(+), 14 deletions(-)

diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py
index 22664dd7f1..bd70913ff2 100644
--- a/src/diffusers/models/unet_ldm.py
+++ b/src/diffusers/models/unet_ldm.py
@@ -6,19 +6,18 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-
-try:
-    from einops import rearrange, repeat
-except:
-    print("Einops is not installed")
-    pass
-
 from ..configuration_utils import ConfigMixin
 from ..modeling_utils import ModelMixin
 from .embeddings import get_timestep_embedding
 
+#try:
+#    from einops import rearrange, repeat
+#except:
+#    print("Einops is not installed")
+#    pass
+
+
 def exists(val):
     return val is not None
 
@@ -153,7 +152,23 @@ class CrossAttention(nn.Module):
 
         self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
 
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
+
     def forward(self, x, context=None, mask=None):
+        batch_size, sequence_length, dim = x.shape
+
         h = self.heads
 
         q = self.to_q(x)
@@ -161,21 +176,29 @@ class CrossAttention(nn.Module):
         k = self.to_k(context)
         v = self.to_v(context)
 
-        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+#        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
+
+        q = self.reshape_heads_to_batch_dim(q)
+        k = self.reshape_heads_to_batch_dim(k)
+        v = self.reshape_heads_to_batch_dim(v)
 
         sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
         if exists(mask):
-            mask = rearrange(mask, "b ... -> b (...)")
+#            mask = rearrange(mask, "b ... -> b (...)")
+            mask = mask.reshape(batch_size, -1)
             max_neg_value = -torch.finfo(sim.dtype).max
-            mask = repeat(mask, "b j -> (b h) () j", h=h)
+#            mask = repeat(mask, "b j -> (b h) () j", h=h)
+            mask = mask[:, None, :].repeat(h, 1, 1)
+#            x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
             sim.masked_fill_(~mask, max_neg_value)
 
         # attention, what we cannot get enough of
         attn = sim.softmax(dim=-1)
 
         out = torch.einsum("b i j, b j d -> b i d", attn, v)
-        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
+        out = self.reshape_batch_dim_to_heads(out)
+#        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
         return self.to_out(out)
 
 
@@ -233,10 +256,10 @@ class SpatialTransformer(nn.Module):
         x_in = x
         x = self.norm(x)
         x = self.proj_in(x)
-        x = rearrange(x, "b c h w -> b (h w) c")
+        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
         for block in self.transformer_blocks:
             x = block(x, context=context)
-        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
+        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
         x = self.proj_out(x)
         return x + x_in
 

From 1cf7933ea234b9aa0ba5b13fbe60740fa855e838 Mon Sep 17 00:00:00 2001
From: anton-l
Date: Mon, 27 Jun 2022 17:11:01 +0200
Subject: [PATCH 2/6] Framework-agnostic timestep broadcasting

---
 examples/train_unconditional.py              |  9 ++++---
 src/diffusers/schedulers/scheduling_ddpm.py  | 12 +++------
 src/diffusers/schedulers/scheduling_utils.py | 28 ++++++++++++++++++++
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py
index 846dd3eda4..fe45f2a5fa 100644
--- a/examples/train_unconditional.py
+++ b/examples/train_unconditional.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 import PIL.Image
 from accelerate import Accelerator
 from datasets import load_dataset
-from diffusers import DDPM, DDPMScheduler, UNetModel
+from diffusers import DDPMPipeline, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
@@ -71,7 +71,7 @@ def main(args):
         model, optimizer, train_dataloader, lr_scheduler
     )
 
-    ema_model = EMAModel(model, inv_gamma=1.0, power=3 / 4)
+    ema_model = EMAModel(model, inv_gamma=args.ema_inv_gamma, power=args.ema_power, max_value=args.ema_max_decay)
 
     if args.push_to_hub:
         repo = init_git_repo(args, at_init=True)
@@ -133,7 +133,7 @@ def main(args):
         # Generate a sample image for visual inspection
         if accelerator.is_main_process:
             with torch.no_grad():
-                pipeline = DDPM(
+                pipeline = DDPMPipeline(
                     unet=accelerator.unwrap_model(ema_model.averaged_model), noise_scheduler=noise_scheduler
                 )
 
@@ -172,6 +172,9 @@ if __name__ == "__main__":
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
     parser.add_argument("--lr", type=float, default=1e-4)
     parser.add_argument("--warmup_steps", type=int, default=500)
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=3/4)
+    parser.add_argument("--ema_max_decay", type=float, default=0.999)
     parser.add_argument("--push_to_hub", action="store_true")
     parser.add_argument("--hub_token", type=str, default=None)
     parser.add_argument("--hub_model_id", type=str, default=None)
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index 5dea0b22b3..d908850dfe 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -144,16 +144,12 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return pred_prev_sample def training_step(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor): - if timesteps.dim() != 1: - raise ValueError("`timesteps` must be a 1D tensor") - - device = original_samples.device - batch_size = original_samples.shape[0] - timesteps = timesteps.reshape(batch_size, 1, 1, 1) - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - noisy_samples = sqrt_alpha_prod.to(device) * original_samples + sqrt_one_minus_alpha_prod.to(device) * noise + sqrt_one_minus_alpha_prod = self.match_shape(sqrt_one_minus_alpha_prod, original_samples) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise return noisy_samples def __len__(self): diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index a6f317852d..4cfbc5e59d 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,6 +14,8 @@ import numpy as np import torch +from typing import Union + SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -50,3 +52,29 @@ class SchedulerMixin: return torch.log(tensor) raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") + + def match_shape( + self, + values: Union[np.ndarray, torch.Tensor], + broadcast_array: Union[np.ndarray, torch.Tensor] + ): + """ + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + + Args: + timesteps: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + + tensor_format = getattr(self, "tensor_format", "pt") + values = values.flatten() + + while len(values.shape) < len(broadcast_array.shape): + values = values[..., None] + if tensor_format == "pt": + values = values.to(broadcast_array.device) + + return values From 07ff0abff4484aad441ceb64c11e60887aac4522 Mon Sep 17 00:00:00 2001 From: anton-l Date: Mon, 27 Jun 2022 17:25:59 +0200 Subject: [PATCH 3/6] Glide and LDM training experiments --- .../experimental/train_glide_text_to_image.py | 201 ++++++++++++++++++ examples/train_latent_text_to_image.py | 76 ++++--- 2 files changed, 246 insertions(+), 31 deletions(-) create mode 100644 examples/experimental/train_glide_text_to_image.py diff --git a/examples/experimental/train_glide_text_to_image.py b/examples/experimental/train_glide_text_to_image.py new file mode 100644 index 0000000000..9b1f28d680 --- /dev/null +++ b/examples/experimental/train_glide_text_to_image.py @@ -0,0 +1,201 @@ +import argparse +import os + +import torch +import torch.nn.functional as F + +import bitsandbytes as bnb +import PIL.Image +from accelerate import Accelerator +from datasets import load_dataset +from diffusers import DDPMScheduler, Glide, GlideUNetModel +from diffusers.hub_utils import init_git_repo, push_to_hub +from diffusers.optimization import get_scheduler +from diffusers.utils import logging +from torchvision.transforms import ( + CenterCrop, + Compose, + InterpolationMode, + Normalize, + RandomHorizontalFlip, + Resize, + ToTensor, +) +from tqdm.auto import tqdm + + +logger = logging.get_logger(__name__) + + +def main(args): + accelerator = Accelerator(mixed_precision=args.mixed_precision) + + pipeline = Glide.from_pretrained("fusing/glide-base") + model = pipeline.text_unet + noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") + optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) + + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) + dataset = load_dataset(args.dataset, split="train") + + text_encoder = pipeline.text_encoder.eval() + + def transforms(examples): + images = [augmentations(image.convert("RGB")) for image in examples["image"]] + text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") + text_inputs = text_inputs.input_ids.to(accelerator.device) + with torch.no_grad(): + text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs).last_hidden_state + return {"images": images, "text_embeddings": text_embeddings} + + dataset.set_transform(transforms) + train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) + + lr_scheduler = get_scheduler( + "linear", + optimizer=optimizer, + num_warmup_steps=args.warmup_steps, + num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, + ) + + model, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + if args.push_to_hub: + repo = init_git_repo(args, at_init=True) + + # Train! 
+ is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized() + world_size = torch.distributed.get_world_size() if is_distributed else 1 + total_train_batch_size = args.batch_size * args.gradient_accumulation_steps * world_size + max_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_epochs + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataloader.dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + + for epoch in range(args.num_epochs): + model.train() + with tqdm(total=len(train_dataloader), unit="ba") as pbar: + pbar.set_description(f"Epoch {epoch}") + for step, batch in enumerate(train_dataloader): + clean_images = batch["images"] + batch_size, n_channels, height, width = clean_images.shape + noise_samples = torch.randn(clean_images.shape).to(clean_images.device) + timesteps = torch.randint( + 0, noise_scheduler.timesteps, (batch_size,), device=clean_images.device + ).long() + + # add noise onto the clean images according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) + + if step % args.gradient_accumulation_steps != 0: + with accelerator.no_sync(model): + model_output = model(noisy_images, timesteps, batch["text_embeddings"]) + model_output, model_var_values = torch.split(model_output, n_channels, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + + # predict the noise residual + loss = F.mse_loss(model_output, noise_samples) + + loss = loss / args.gradient_accumulation_steps + + accelerator.backward(loss) + optimizer.step() + else: + model_output = model(noisy_images, timesteps, batch["text_embeddings"]) + model_output, model_var_values = torch.split(model_output, n_channels, dim=1) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. 
+ frozen_out = torch.cat([model_output.detach(), model_var_values], dim=1) + + # predict the noise residual + loss = F.mse_loss(model_output, noise_samples) + loss = loss / args.gradient_accumulation_steps + accelerator.backward(loss) + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + pbar.update(1) + pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + + accelerator.wait_for_everyone() + + # Generate a sample image for visual inspection + if accelerator.is_main_process: + model.eval() + with torch.no_grad(): + pipeline.unet = accelerator.unwrap_model(model) + + generator = torch.manual_seed(0) + # run pipeline in inference (sample random noise and denoise) + image = pipeline("a clip art of a corgi", generator=generator, num_upscale_inference_steps=50) + + # process image to PIL + image_processed = image.squeeze(0) + image_processed = ((image_processed + 1) * 127.5).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + image_pil = PIL.Image.fromarray(image_processed) + + # save image + test_dir = os.path.join(args.output_dir, "test_samples") + os.makedirs(test_dir, exist_ok=True) + image_pil.save(f"{test_dir}/{epoch:04d}.png") + + # save the model + if args.push_to_hub: + push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) + else: + pipeline.save_pretrained(args.output_dir) + accelerator.wait_for_everyone() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument("--local_rank", type=int, default=-1) + parser.add_argument("--dataset", type=str, default="fusing/dog_captions") + parser.add_argument("--output_dir", type=str, default="glide-text2image") + parser.add_argument("--overwrite_output_dir", action="store_true") + parser.add_argument("--resolution", type=int, default=64) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--num_epochs", type=int, default=100) + parser.add_argument("--gradient_accumulation_steps", type=int, default=4) + parser.add_argument("--lr", type=float, default=1e-4) + parser.add_argument("--warmup_steps", type=int, default=500) + parser.add_argument("--push_to_hub", action="store_true") + parser.add_argument("--hub_token", type=str, default=None) + parser.add_argument("--hub_model_id", type=str, default=None) + parser.add_argument("--hub_private_repo", action="store_true") + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." 
+ ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + main(args) diff --git a/examples/train_latent_text_to_image.py b/examples/train_latent_text_to_image.py index fd823fdad9..7cbfa2c49d 100644 --- a/examples/train_latent_text_to_image.py +++ b/examples/train_latent_text_to_image.py @@ -4,19 +4,19 @@ import os import torch import torch.nn.functional as F +import bitsandbytes as bnb import PIL.Image from accelerate import Accelerator from datasets import load_dataset -from diffusers import DDPM, DDPMScheduler, UNetLDMModel +from diffusers import DDPMScheduler, LatentDiffusion, UNetLDMModel from diffusers.hub_utils import init_git_repo, push_to_hub -from diffusers.modeling_utils import unwrap_model from diffusers.optimization import get_scheduler from diffusers.utils import logging from torchvision.transforms import ( CenterCrop, Compose, InterpolationMode, - Lambda, + Normalize, RandomHorizontalFlip, Resize, ToTensor, @@ -30,6 +30,8 @@ logger = logging.get_logger(__name__) def main(args): accelerator = Accelerator(mixed_precision=args.mixed_precision) + pipeline = LatentDiffusion.from_pretrained("fusing/latent-diffusion-text2im-large") + pipeline.unet = None # this model will be trained from scratch now model = UNetLDMModel( attention_resolutions=[4, 2, 1], channel_mult=[1, 2, 4, 4], @@ -37,7 +39,7 @@ def main(args): conv_resample=True, dims=2, dropout=0, - image_size=32, + image_size=8, in_channels=4, model_channels=320, num_heads=8, @@ -51,7 +53,7 @@ def main(args): legacy=False, ) noise_scheduler = DDPMScheduler(timesteps=1000, tensor_format="pt") - optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) + optimizer = bnb.optim.Adam8bit(model.parameters(), lr=args.lr) augmentations = Compose( [ @@ -59,14 +61,22 @@ def main(args): CenterCrop(args.resolution), RandomHorizontalFlip(), ToTensor(), - Lambda(lambda x: x * 2 - 1), + Normalize([0.5], [0.5]), ] ) dataset = load_dataset(args.dataset, split="train") + text_encoder = pipeline.bert.eval() + vqvae = pipeline.vqvae.eval() + def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] - return {"input": images} + text_inputs = pipeline.tokenizer(examples["caption"], padding="max_length", max_length=77, return_tensors="pt") + with torch.no_grad(): + text_embeddings = accelerator.unwrap_model(text_encoder)(text_inputs.input_ids.cpu()).last_hidden_state + images = 1 / 0.18215 * torch.stack(images, dim=0) + latents = accelerator.unwrap_model(vqvae).encode(images.cpu()).mode() + return {"images": images, "text_embeddings": text_embeddings, "latents": latents} dataset.set_transform(transforms) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True) @@ -78,9 +88,11 @@ def main(args): num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps, ) - model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - model, optimizer, train_dataloader, lr_scheduler + model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + model, text_encoder, vqvae, optimizer, train_dataloader, lr_scheduler ) + text_encoder = text_encoder.cpu() + vqvae = vqvae.cpu() if args.push_to_hub: repo = init_git_repo(args, at_init=True) @@ -98,29 +110,31 @@ def main(args): logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 
logger.info(f" Total optimization steps = {max_steps}") + global_step = 0 for epoch in range(args.num_epochs): model.train() with tqdm(total=len(train_dataloader), unit="ba") as pbar: pbar.set_description(f"Epoch {epoch}") for step, batch in enumerate(train_dataloader): - clean_images = batch["input"] - noise_samples = torch.randn(clean_images.shape).to(clean_images.device) - bsz = clean_images.shape[0] - timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_images.device).long() + clean_latents = batch["latents"] + noise_samples = torch.randn(clean_latents.shape).to(clean_latents.device) + bsz = clean_latents.shape[0] + timesteps = torch.randint(0, noise_scheduler.timesteps, (bsz,), device=clean_latents.device).long() - # add noise onto the clean images according to the noise magnitude at each timestep + # add noise onto the clean latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_images = noise_scheduler.training_step(clean_images, noise_samples, timesteps) + noisy_latents = noise_scheduler.training_step(clean_latents, noise_samples, timesteps) if step % args.gradient_accumulation_steps != 0: with accelerator.no_sync(model): - output = model(noisy_images, timesteps) + output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) # predict the noise residual loss = F.mse_loss(output, noise_samples) loss = loss / args.gradient_accumulation_steps accelerator.backward(loss) + optimizer.step() else: - output = model(noisy_images, timesteps) + output = model(noisy_latents, timesteps, context=batch["text_embeddings"]) # predict the noise residual loss = F.mse_loss(output, noise_samples) loss = loss / args.gradient_accumulation_steps @@ -131,24 +145,25 @@ def main(args): optimizer.zero_grad() pbar.update(1) pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"]) + global_step += 1 - optimizer.step() - if is_distributed: - torch.distributed.barrier() + accelerator.wait_for_everyone() # Generate a sample image for visual inspection - if args.local_rank in [-1, 0]: + if accelerator.is_main_process: model.eval() with torch.no_grad(): - pipeline = DDPM(unet=unwrap_model(model), noise_scheduler=noise_scheduler) + pipeline.unet = accelerator.unwrap_model(model) generator = torch.manual_seed(0) # run pipeline in inference (sample random noise and denoise) - image = pipeline(generator=generator) + image = pipeline( + ["a clip art of a corgi"], generator=generator, eta=0.3, guidance_scale=6.0, num_inference_steps=50 + ) # process image to PIL image_processed = image.cpu().permute(0, 2, 3, 1) - image_processed = (image_processed + 1.0) * 127.5 + image_processed = image_processed * 255.0 image_processed = image_processed.type(torch.uint8).numpy() image_pil = PIL.Image.fromarray(image_processed[0]) @@ -162,20 +177,19 @@ def main(args): push_to_hub(args, pipeline, repo, commit_message=f"Epoch {epoch}", blocking=False) else: pipeline.save_pretrained(args.output_dir) - if is_distributed: - torch.distributed.barrier() + accelerator.wait_for_everyone() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--local_rank", type=int, default=-1) - parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories") - parser.add_argument("--output_dir", type=str, default="ddpm-model") + parser.add_argument("--dataset", type=str, default="fusing/dog_captions") + parser.add_argument("--output_dir", type=str, 
default="ldm-text2image") parser.add_argument("--overwrite_output_dir", action="store_true") - parser.add_argument("--resolution", type=int, default=64) - parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--resolution", type=int, default=128) + parser.add_argument("--batch_size", type=int, default=1) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--gradient_accumulation_steps", type=int, default=1) + parser.add_argument("--gradient_accumulation_steps", type=int, default=16) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--push_to_hub", action="store_true") From af6c143919fa122b57f11851a89ac1c63b9a272b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 15:34:11 +0000 Subject: [PATCH 4/6] remove einops --- src/diffusers/models/resnet.py | 9 ++--- src/diffusers/models/unet_grad_tts.py | 18 +++++----- src/diffusers/models/unet_ldm.py | 24 +++++++------ src/diffusers/models/unet_rl.py | 50 +++++++++++++++++++-------- 4 files changed, 63 insertions(+), 38 deletions(-) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 04e3735d60..a56437ad85 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import torch.nn.functional as F @@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs): return nn.Conv3d(*args, **kwargs) raise ValueError(f"unsupported dimensions: {dims}") + def conv_transpose_nd(dims, *args, **kwargs): """ Create a 1D, 2D, or 3D convolution module. @@ -81,15 +81,15 @@ class Upsample(nn.Module): assert x.shape[1] == self.channels if self.use_conv_transpose: return self.conv(x) - + if self.dims == 3: x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") else: x = F.interpolate(x, scale_factor=2.0, mode="nearest") - + if self.use_conv: x = self.conv(x) - + return x @@ -138,6 +138,7 @@ class UNetUpsample(nn.Module): x = self.conv(x) return x + class GlideUpsample(nn.Module): """ An upsampling layer with an optional convolution. 
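Before the patch continues with unet_grad_tts.py and unet_rl.py, a standalone sanity check (not part of any commit in this series; the tensor sizes below are made up) shows that the reshape/permute rewrites match the einops patterns they replace:

    # Standalone check: the pure-PyTorch reshapes used in CrossAttention are
    # equivalent to the einops calls removed by these patches.
    import torch
    from einops import rearrange

    b, n, heads, d = 2, 16, 8, 64
    x = torch.randn(b, n, heads * d)

    # pattern removed from CrossAttention.forward: "b n (h d) -> (b h) n d"
    expected = rearrange(x, "b n (h d) -> (b h) n d", h=heads)

    # replacement, as in CrossAttention.reshape_heads_to_batch_dim
    out = x.reshape(b, n, heads, d).permute(0, 2, 1, 3).reshape(b * heads, n, d)
    assert torch.equal(out, expected)

    # inverse, as in CrossAttention.reshape_batch_dim_to_heads
    back = out.reshape(b, heads, n, d).permute(0, 2, 1, 3).reshape(b, n, heads * d)
    assert torch.equal(back, x)
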
diff --git a/src/diffusers/models/unet_grad_tts.py b/src/diffusers/models/unet_grad_tts.py index 36bcce53e9..9304732e15 100644 --- a/src/diffusers/models/unet_grad_tts.py +++ b/src/diffusers/models/unet_grad_tts.py @@ -1,12 +1,5 @@ import torch - -try: - from einops import rearrange -except: - print("Einops is not installed") - pass - from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding @@ -81,6 +74,7 @@ class LinearAttention(torch.nn.Module): def __init__(self, dim, heads=4, dim_head=32): super(LinearAttention, self).__init__() self.heads = heads + self.dim_head = dim_head hidden_dim = dim_head * heads self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1) @@ -88,11 +82,17 @@ class LinearAttention(torch.nn.Module): def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) - q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + # q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3) + q, k, v = ( + qkv.reshape(b, 3, self.heads, self.dim_head, h, w) + .permute(1, 0, 2, 3, 4, 5) + .reshape(3, b, self.heads, self.dim_head, -1) + ) k = k.softmax(dim=-1) context = torch.einsum("bhdn,bhen->bhde", k, v) out = torch.einsum("bhde,bhdn->bhen", context, q) - out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + # out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w) + out = out.reshape(b, self.heads, self.dim_head, h, w).reshape(b, self.heads * self.dim_head, h, w) return self.to_out(out) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index bd70913ff2..4403309d51 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -6,14 +6,15 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F + from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin from .embeddings import get_timestep_embedding -#try: +# try: # from einops import rearrange, repeat -#except: +# except: # print("Einops is not installed") # pass @@ -80,7 +81,7 @@ def Normalize(in_channels): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) -#class LinearAttention(nn.Module): +# class LinearAttention(nn.Module): # def __init__(self, dim, heads=4, dim_head=32): # super().__init__() # self.heads = heads @@ -100,7 +101,7 @@ def Normalize(in_channels): # return self.to_out(out) # -#class SpatialSelfAttention(nn.Module): +# class SpatialSelfAttention(nn.Module): # def __init__(self, in_channels): # super().__init__() # self.in_channels = in_channels @@ -118,7 +119,7 @@ def Normalize(in_channels): # k = self.k(h_) # v = self.v(h_) # - # compute attention +# compute attention # b, c, h, w = q.shape # q = rearrange(q, "b c h w -> b (h w) c") # k = rearrange(k, "b c h w -> b c (h w)") @@ -127,7 +128,7 @@ def Normalize(in_channels): # w_ = w_ * (int(c) ** (-0.5)) # w_ = torch.nn.functional.softmax(w_, dim=2) # - # attend to values +# attend to values # v = rearrange(v, "b c h w -> b c (h w)") # w_ = rearrange(w_, "b i j -> b j i") # h_ = torch.einsum("bij,bjk->bik", v, w_) @@ -137,6 +138,7 @@ def Normalize(in_channels): # return x + h_ # + class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): super().__init__() @@ -176,7 +178,7 @@ 
class CrossAttention(nn.Module): k = self.to_k(context) v = self.to_v(context) -# q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + # q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) q = self.reshape_heads_to_batch_dim(q) k = self.reshape_heads_to_batch_dim(k) @@ -185,12 +187,12 @@ class CrossAttention(nn.Module): sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): -# mask = rearrange(mask, "b ... -> b (...)") + # mask = rearrange(mask, "b ... -> b (...)") maks = mask.reshape(batch_size, -1) max_neg_value = -torch.finfo(sim.dtype).max -# mask = repeat(mask, "b j -> (b h) () j", h=h) + # mask = repeat(mask, "b j -> (b h) () j", h=h) mask = mask[:, None, :].repeat(h, 1, 1) -# x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + # x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of @@ -198,7 +200,7 @@ class CrossAttention(nn.Module): out = torch.einsum("b i j, b j d -> b i d", attn, v) out = self.reshape_batch_dim_to_heads(out) -# out = rearrange(out, "(b h) n d -> b n (h d)", h=h) + # out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return self.to_out(out) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index 28fea5753c..a6025eeb3b 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -5,18 +5,19 @@ import math import torch import torch.nn as nn - -try: - import einops - from einops.layers.torch import Rearrange -except: - print("Einops is not installed") - pass - from ..configuration_utils import ConfigMixin from ..modeling_utils import ModelMixin +# try: +# import einops +# from einops.layers.torch import Rearrange +# except: +# print("Einops is not installed") +# pass + + + class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() @@ -50,6 +51,21 @@ class Upsample1d(nn.Module): return self.conv(x) +class RearrangeDim(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + class Conv1dBlock(nn.Module): """ Conv1d --> GroupNorm --> Mish @@ -60,9 +76,11 @@ class Conv1dBlock(nn.Module): self.block = nn.Sequential( nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), - Rearrange("batch channels horizon -> batch channels 1 horizon"), + RearrangeDim(), + # Rearrange("batch channels horizon -> batch channels 1 horizon"), nn.GroupNorm(n_groups, out_channels), - Rearrange("batch channels 1 horizon -> batch channels horizon"), + RearrangeDim(), + # Rearrange("batch channels 1 horizon -> batch channels horizon"), nn.Mish(), ) @@ -84,7 +102,8 @@ class ResidualTemporalBlock(nn.Module): self.time_mlp = nn.Sequential( nn.Mish(), nn.Linear(embed_dim, out_channels), - Rearrange("batch t -> batch t 1"), + RearrangeDim(), + # Rearrange("batch t -> batch t 1"), ) self.residual_conv = ( @@ -184,7 +203,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): x : [ batch x horizon x transition ] """ - x = einops.rearrange(x, "b h t -> b t h") + # x = einops.rearrange(x, "b h t -> b t h") + x = x.permute(0, 2, 1) t = self.time_mlp(time) h = [] @@ -206,7 +226,8 @@ class TemporalUNet(ModelMixin, ConfigMixin): # (nn.Module): x = 
self.final_conv(x) - x = einops.rearrange(x, "b t h -> b h t") + # x = einops.rearrange(x, "b t h -> b h t") + x = x.permute(0, 2, 1) return x @@ -263,7 +284,8 @@ class TemporalValue(nn.Module): x : [ batch x horizon x transition ] """ - x = einops.rearrange(x, "b h t -> b t h") + # x = einops.rearrange(x, "b h t -> b t h") + x = x.permute(0, 2, 1) t = self.time_mlp(time) From 932ce05d977a2e3f4108bca25f96b21b85b6733d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 15:39:41 +0000 Subject: [PATCH 5/6] cancel einops --- examples/train_unconditional.py | 2 +- src/diffusers/models/unet_ldm.py | 7 ------ src/diffusers/models/unet_rl.py | 1 - src/diffusers/schedulers/scheduling_utils.py | 24 ++++++++------------ tests/test_modeling_utils.py | 3 ++- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/examples/train_unconditional.py b/examples/train_unconditional.py index fe45f2a5fa..5398c78268 100644 --- a/examples/train_unconditional.py +++ b/examples/train_unconditional.py @@ -173,7 +173,7 @@ if __name__ == "__main__": parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument("--ema_inv_gamma", type=float, default=1.0) - parser.add_argument("--ema_power", type=float, default=3/4) + parser.add_argument("--ema_power", type=float, default=3 / 4) parser.add_argument("--ema_max_decay", type=float, default=0.999) parser.add_argument("--push_to_hub", action="store_true") parser.add_argument("--hub_token", type=str, default=None) diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index c76188cce4..378fdd57a2 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -13,13 +13,6 @@ from .embeddings import get_timestep_embedding from .resnet import Upsample -# try: -# from einops import rearrange, repeat -# except: -# print("Einops is not installed") -# pass - - def exists(val): return val is not None diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index a6025eeb3b..a0b8c5e47a 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -17,7 +17,6 @@ from ..modeling_utils import ModelMixin # pass - class SinusoidalPosEmb(nn.Module): def __init__(self, dim): super().__init__() diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 4cfbc5e59d..7c5972434b 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import numpy as np import torch -from typing import Union - SCHEDULER_CONFIG_NAME = "scheduler_config.json" @@ -53,20 +53,16 @@ class SchedulerMixin: raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - def match_shape( - self, - values: Union[np.ndarray, torch.Tensor], - broadcast_array: Union[np.ndarray, torch.Tensor] - ): + def match_shape(self, values: Union[np.ndarray, torch.Tensor], broadcast_array: Union[np.ndarray, torch.Tensor]): """ - Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. + Turns a 1-D array into an array or tensor with len(broadcast_array.shape) dims. - Args: - timesteps: an array or tensor of values to extract. 
- broadcast_array: an array with a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - Returns: - a tensor of shape [batch_size, 1, ...] where the shape has K dims. + Args: + timesteps: an array or tensor of values to extract. + broadcast_array: an array with a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + Returns: + a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ tensor_format = getattr(self, "tensor_format", "pt") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 85c7ef5d0e..aa40513621 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -21,7 +21,8 @@ import unittest import numpy as np import torch -from diffusers import ( # GradTTSPipeline, +from diffusers import ( + GradTTSPipeline, BDDMPipeline, DDIMPipeline, DDIMScheduler, From 4261c3aadfc23ee5b123b80ab7d8680a013acb66 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 27 Jun 2022 15:59:04 +0000 Subject: [PATCH 6/6] Make style --- Makefile | 11 +- src/diffusers/hub_utils.py | 13 +- src/diffusers/modeling_utils.py | 28 +- src/diffusers/models/embeddings.py | 8 +- src/diffusers/models/resnet.py | 22 +- src/diffusers/models/unet_glide.py | 71 ++--- src/diffusers/models/unet_ldm.py | 114 +++----- src/diffusers/models/unet_rl.py | 6 +- .../models/unet_sde_score_estimation.py | 97 +++---- src/diffusers/pipelines/grad_tts_utils.py | 6 +- src/diffusers/pipelines/pipeline_bddm.py | 6 +- src/diffusers/pipelines/pipeline_glide.py | 5 +- src/diffusers/pipelines/pipeline_grad_tts.py | 2 +- .../pipelines/pipeline_latent_diffusion.py | 12 +- src/diffusers/schedulers/scheduling_ddim.py | 11 +- src/diffusers/schedulers/scheduling_ddpm.py | 11 +- src/diffusers/schedulers/scheduling_pndm.py | 11 +- src/diffusers/training_utils.py | 9 +- src/diffusers/utils/__init__.py | 12 +- ...rmers_and_inflect_and_unidecode_objects.py | 2 +- .../utils/dummy_transformers_objects.py | 4 +- src/diffusers/utils/logging.py | 4 +- tests/test_modeling_utils.py | 2 +- utils/check_copies.py | 24 +- utils/custom_init_isort.py | 250 ++++++++++++++++++ 25 files changed, 451 insertions(+), 290 deletions(-) create mode 100644 utils/custom_init_isort.py diff --git a/Makefile b/Makefile index 83a84fe461..ec8237e15f 100644 --- a/Makefile +++ b/Makefile @@ -34,13 +34,9 @@ autogenerate_code: deps_table_update # Check that the repo is in a good state repo-consistency: - python utils/check_copies.py - python utils/check_table.py python utils/check_dummies.py python utils/check_repo.py python utils/check_inits.py - python utils/check_config_docstrings.py - python utils/tests_fetcher.py --sanity_check # this target runs checks on all files @@ -48,14 +44,13 @@ quality: black --check --preview $(check_dirs) isort --check-only $(check_dirs) flake8 $(check_dirs) - doc-builder style src/transformers docs/source --max_len 119 --check_only --path_to_docs docs/source + doc-builder style src/diffusers docs/source --max_len 119 --check_only --path_to_docs docs/source # Format source code automatically and check is there are any problems left that need manual fixing extra_style_checks: python utils/custom_init_isort.py - python utils/sort_auto_mappings.py - doc-builder style src/transformers docs/source --max_len 119 --path_to_docs docs/source + doc-builder style src/diffusers docs/source --max_len 119 --path_to_docs docs/source # this target runs checks on all files and potentially modifies some of them @@ -73,8 +68,6 @@ 
fixup: modified_only_fixup extra_style_checks autogenerate_code repo-consistency fix-copies: python utils/check_dummies.py --fix_and_overwrite - python utils/check_table.py --fix_and_overwrite - python utils/check_copies.py --fix_and_overwrite # Run tests for the library diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index c2d1e34f3e..2ab2ff289a 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -47,12 +47,11 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def init_git_repo(args, at_init: bool = False): """ - Initializes a git repo in `args.hub_model_id`. Args: + Initializes a git repo in `args.hub_model_id`. at_init (`bool`, *optional*, defaults to `False`): - Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is - `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped - out. + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` + and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. """ if args.local_rank not in [-1, 0]: return @@ -102,8 +101,8 @@ def push_to_hub( **kwargs, ) -> str: """ - Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. Parameters: + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. commit_message (`str`, *optional*, defaults to `"End of training"`): Message to commit while pushing. blocking (`bool`, *optional*, defaults to `True`): @@ -111,8 +110,8 @@ def push_to_hub( kwargs: Additional keyword arguments passed along to [`create_model_card`]. Returns: - The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of - the commit and an object to track the progress of the commit if `blocking=True` + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of the + commit and an object to track the progress of the commit if `blocking=True` """ if args.hub_model_id is None: diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 0b3d072b70..aa60ffa936 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -123,16 +123,16 @@ class ModelMixin(torch.nn.Module): r""" Base class for all models. - [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, - downloading and saving models as well as a few methods common to all models to: + [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading + and saving models as well as a few methods common to all models to: - resize the input embeddings, - prune heads in the self-attention heads. Class attributes (overridden by derived classes): - - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class - for this model architecture. + - **config_class** ([`ConfigMixin`]) -- A subclass of [`ConfigMixin`] to use as configuration class for this + model architecture. - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: @@ -227,8 +227,8 @@ class ModelMixin(torch.nn.Module): - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. 
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - - A path to a *directory* containing model weights saved using - [`~ModelMixin.save_pretrained`], e.g., `./my_model_directory/`. + - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], + e.g., `./my_model_directory/`. config (`Union[ConfigMixin, str, os.PathLike]`, *optional*): Can be either: @@ -236,13 +236,13 @@ class ModelMixin(torch.nn.Module): - an instance of a class derived from [`ConfigMixin`], - a string or path valid as input to [`~ConfigMixin.from_pretrained`]. - ConfigMixinuration for the model to use instead of an automatically loaded configuration. ConfigMixinuration can - be automatically loaded when: + ConfigMixinuration for the model to use instead of an automatically loaded configuration. + ConfigMixinuration can be automatically loaded when: - The model is a model provided by the library (loaded with the *model id* string of a pretrained model). - - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the - save directory. + - The model was saved using [`~ModelMixin.save_pretrained`] and is reloaded by supplying the save + directory. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a configuration JSON file named *config.json* is found in the directory. cache_dir (`Union[str, os.PathLike]`, *optional*): @@ -292,10 +292,10 @@ class ModelMixin(torch.nn.Module): underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. + initialization function ([`~ConfigMixin.from_pretrained`]). Each key of `kwargs` that corresponds + to a configuration attribute will be used to override said attribute with the supplied `kwargs` + value. Remaining keys that do not correspond to any configuration attribute will be passed to the + underlying model's `__init__` function. diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index f31b64ee5c..e70f39319e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -22,14 +22,12 @@ def get_timestep_embedding( timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1, scale=1, max_period=10000 ): """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - Create sinusoidal timestep embeddings. + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. - :param embedding_dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 9e5ef17641..6560d34559 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -58,9 +58,8 @@ class Upsample(nn.Module): """ An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ @@ -97,9 +96,8 @@ class Downsample(nn.Module): """ A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ @@ -143,9 +141,8 @@ class GlideUpsample(nn.Module): """ An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. """ @@ -171,10 +168,9 @@ class GlideUpsample(nn.Module): class LDMUpsample(nn.Module): """ - An upsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + An upsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param + use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. + If 3D, then upsampling occurs in the inner-two dimensions. """ diff --git a/src/diffusers/models/unet_glide.py b/src/diffusers/models/unet_glide.py index 9a50b9cb52..d357d0cc8a 100644 --- a/src/diffusers/models/unet_glide.py +++ b/src/diffusers/models/unet_glide.py @@ -82,8 +82,7 @@ def normalization(channels, swish=0.0): """ Make a standard normalization layer, with an optional swish activation. - :param channels: number of input channels. - :return: an nn.Module for normalization. + :param channels: number of input channels. :return: an nn.Module for normalization. """ return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) @@ -111,8 +110,7 @@ class TimestepBlock(nn.Module): class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. + A sequential module that passes timestep embeddings to the children that support it as an extra input. 
""" def forward(self, x, emb, encoder_out=None): @@ -130,9 +128,8 @@ class Downsample(nn.Module): """ A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + :param channels: channels in the inputs and outputs. :param use_conv: a bool determining if a convolution is + applied. :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. """ @@ -158,17 +155,13 @@ class ResBlock(TimestepBlock): """ A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. + :param channels: the number of input channels. :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. :param out_channels: if specified, the number of out channels. :param + use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing + on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for + downsampling. """ def __init__( @@ -235,8 +228,7 @@ class ResBlock(TimestepBlock): """ Apply the block to a Tensor, conditioned on a timestep embedding. - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :param x: an [N x C x ...] Tensor of features. :param emb: an [N x emb_channels] Tensor of timestep embeddings. :return: an [N x C x ...] Tensor of outputs. """ if self.updown: @@ -320,8 +312,8 @@ class QKVAttention(nn.Module): """ Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x T] tensor after + attention. """ bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 @@ -343,29 +335,24 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. + :param in_channels: channels in the input Tensor. :param model_channels: base channel count for the model. :param + out_channels: channels in the output Tensor. :param num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. 
- For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and + attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x + downsampling, attention will be used. + :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param + conv_resample: if True, use learned convolutions for upsampling and downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be + :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this + model will be class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use + :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention + heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks + for up/downsampling. """ def __init__( @@ -571,10 +558,8 @@ class GlideUNetModel(ModelMixin, ConfigMixin): """ Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. + :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch of timesteps. :param y: an [N] + Tensor of labels, if class-conditional. :return: an [N x C x ...] Tensor of outputs. """ hs = [] diff --git a/src/diffusers/models/unet_ldm.py b/src/diffusers/models/unet_ldm.py index 378fdd57a2..f8a8602d2f 100644 --- a/src/diffusers/models/unet_ldm.py +++ b/src/diffusers/models/unet_ldm.py @@ -222,11 +222,8 @@ class BasicTransformerBlock(nn.Module): class SpatialTransformer(nn.Module): """ - Transformer block for image-like data. - First, project the input (aka embedding) - and reshape to b, t, d. - Then apply standard transformer action. - Finally, reshape to image + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image """ def __init__(self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None): @@ -331,8 +328,7 @@ def normalization(channels, swish=0.0): """ Make a standard normalization layer, with an optional swish activation. - :param channels: number of input channels. - :return: an nn.Module for normalization. + :param channels: number of input channels. :return: an nn.Module for normalization. 
""" return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) @@ -382,8 +378,7 @@ class TimestepBlock(nn.Module): class TimestepEmbedSequential(nn.Sequential, TimestepBlock): """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. + A sequential module that passes timestep embeddings to the children that support it as an extra input. """ def forward(self, x, emb, context=None): @@ -399,10 +394,9 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock): class Downsample(nn.Module): """ - A downsampling layer with an optional convolution. - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + A downsampling layer with an optional convolution. :param channels: channels in the inputs and outputs. :param + use_conv: a bool determining if a convolution is applied. :param dims: determines if the signal is 1D, 2D, or 3D. + If 3D, then downsampling occurs in the inner-two dimensions. """ @@ -426,18 +420,14 @@ class Downsample(nn.Module): class ResBlock(TimestepBlock): """ - A residual block that can optionally change the number of channels. - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. + A residual block that can optionally change the number of channels. :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. :param dropout: the rate of dropout. :param + out_channels: if specified, the number of out channels. :param use_conv: if True and out_channels is specified, use + a spatial + convolution instead of a smaller 1x1 convolution to change the channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. :param use_checkpoint: if True, use gradient checkpointing + on this module. :param up: if True, use this block for upsampling. :param down: if True, use this block for + downsampling. """ def __init__( @@ -525,8 +515,8 @@ class ResBlock(TimestepBlock): class AttentionBlock(nn.Module): """ - An attention block that allows spatial positions to attend to each other. - Originally ported from here, but adapted to the N-d case. + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. """ @@ -575,9 +565,8 @@ class QKVAttention(nn.Module): def forward(self, qkv): """ - Apply QKV attention. - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. + Apply QKV attention. :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x + T] tensor after attention. 
""" bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 @@ -600,13 +589,9 @@ class QKVAttention(nn.Module): def count_flops_attn(model, _x, y): """ - A counter for the `thop` package to count the operations in an - attention operation. - Meant to be used like: + A counter for the `thop` package to count the operations in an attention operation. Meant to be used like: macs, params = thop.profile( - model, - inputs=(inputs, timestamps), - custom_ops={QKVAttention: QKVAttention.count_flops}, + model, inputs=(inputs, timestamps), custom_ops={QKVAttention: QKVAttention.count_flops}, ) """ b, c, *spatial = y[0].shape @@ -629,9 +614,8 @@ class QKVAttentionLegacy(nn.Module): def forward(self, qkv): """ - Apply QKV attention. - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. + Apply QKV attention. :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. :return: an [N x (H * C) x + T] tensor after attention. """ bs, width, length = qkv.shape assert width % (3 * self.n_heads) == 0 @@ -650,31 +634,25 @@ class QKVAttentionLegacy(nn.Module): class UNetLDMModel(ModelMixin, ConfigMixin): """ - The full UNet model with attention and timestep embedding. - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. - :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and + The full UNet model with attention and timestep embedding. :param in_channels: channels in the input Tensor. :param + model_channels: base channel count for the model. :param out_channels: channels in the output Tensor. :param + num_res_blocks: number of residual blocks per downsample. :param attention_resolutions: a collection of downsample + rates at which + attention will take place. May be a set, list, or tuple. For example, if this contains 4, then at 4x + downsampling, attention will be used. + :param dropout: the dropout probability. :param channel_mult: channel multiplier for each level of the UNet. :param + conv_resample: if True, use learned convolutions for upsampling and downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be + :param dims: determines if the signal is 1D, 2D, or 3D. :param num_classes: if specified (as an int), then this + model will be class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use + :param use_checkpoint: use gradient checkpointing to reduce memory usage. :param num_heads: the number of attention + heads in each attention layer. :param num_heads_channels: if specified, ignore num_heads and instead use a fixed channel width per attention head. :param num_heads_upsample: works with num_heads to set a different number of heads for upsampling. Deprecated. 
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. :param resblock_updown: use residual blocks + for up/downsampling. :param use_new_attention_order: use a different attention pattern for potentially increased efficiency. """ @@ -975,12 +953,9 @@ class UNetLDMModel(ModelMixin, ConfigMixin): def forward(self, x, timesteps=None, context=None, y=None, **kwargs): """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. + Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch + of timesteps. :param context: conditioning plugged in via crossattn :param y: an [N] Tensor of labels, if + class-conditional. :return: an [N x C x ...] Tensor of outputs. """ assert (y is not None) == ( self.num_classes is not None @@ -1012,8 +987,7 @@ class UNetLDMModel(ModelMixin, ConfigMixin): class EncoderUNetModel(nn.Module): """ - The half UNet model with attention and timestep embedding. - For usage, see UNet. + The half UNet model with attention and timestep embedding. For usage, see UNet. """ def __init__( @@ -1197,10 +1171,8 @@ class EncoderUNetModel(nn.Module): def forward(self, x, timesteps): """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. + Apply the model to an input batch. :param x: an [N x C x ...] Tensor of inputs. :param timesteps: a 1-D batch + of timesteps. :return: an [N x K] Tensor of outputs. """ emb = self.time_embed( get_timestep_embedding(timesteps, self.model_channels, flip_sin_to_cos=True, downscale_freq_shift=0) diff --git a/src/diffusers/models/unet_rl.py b/src/diffusers/models/unet_rl.py index a0b8c5e47a..9c0c77130c 100644 --- a/src/diffusers/models/unet_rl.py +++ b/src/diffusers/models/unet_rl.py @@ -111,10 +111,8 @@ class ResidualTemporalBlock(nn.Module): def forward(self, x, t): """ - x : [ batch_size x inp_channels x horizon ] - t : [ batch_size x embed_dim ] - returns: - out : [ batch_size x out_channels x horizon ] + x : [ batch_size x inp_channels x horizon ] t : [ batch_size x embed_dim ] returns: out : [ batch_size x + out_channels x horizon ] """ out = self.blocks[0](x) + self.time_mlp(t) out = self.blocks[1](out) diff --git a/src/diffusers/models/unet_sde_score_estimation.py b/src/diffusers/models/unet_sde_score_estimation.py index 83700c4b63..44c635922d 100644 --- a/src/diffusers/models/unet_sde_score_estimation.py +++ b/src/diffusers/models/unet_sde_score_estimation.py @@ -136,26 +136,21 @@ def naive_downsample_2d(x, factor=2): def upsample_conv_2d(x, w, k=None, factor=2, gain=1): """Fused `upsample_2d()` followed by `tf.nn.conv2d()`. - Padding is performed only once at the beginning, not between the - operations. - The fused op is considerably more efficient than performing the same - calculation - using standard TensorFlow ops. It supports gradients of arbitrary order. Args: - x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, + Padding is performed only once at the beginning, not between the operations. 
The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+    Args:
+      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
       C]`.
-    w: Weight tensor of the shape `[filterH, filterW, inChannels,
-    outChannels]`. Grouped convolution can be performed by `inChannels =
-    x.shape[0] // numGroups`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to
-    nearest-neighbor upsampling.
-    factor: Integer upsampling factor (default: 2).
-    gain: Scaling factor for signal magnitude (default: 1.0).
+      w: Weight tensor of the shape `[filterH, filterW, inChannels,
+    outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+      k: FIR filter of the shape `[firH, firW]` or `[firN]`
+    (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+      factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

     Returns:
-    Tensor of the shape `[N, C, H * factor, W * factor]` or
-    `[N, H * factor, W * factor, C]`, and same datatype as `x`.
+        Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same datatype as
+        `x`.
     """

     assert isinstance(factor, int) and factor >= 1
@@ -208,25 +203,21 @@ def upsample_conv_2d(x, w, k=None, factor=2, gain=1):
 def conv_downsample_2d(x, w, k=None, factor=2, gain=1):
     """Fused `tf.nn.conv2d()` followed by `downsample_2d()`.

-    Padding is performed only once at the beginning, not between the operations.
-    The fused op is considerably more efficient than performing the same
-    calculation
-    using standard TensorFlow ops. It supports gradients of arbitrary order. Args:
-    x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+    Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
+    efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of arbitrary
+    order.
+    Args:
+      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
       C]`.
-    w: Weight tensor of the shape `[filterH, filterW, inChannels,
-    outChannels]`. Grouped convolution can be performed by `inChannels =
-    x.shape[0] // numGroups`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to
-    average pooling.
-    factor: Integer downsampling factor (default: 2).
-    gain: Scaling factor for signal magnitude (default: 1.0).
+      w: Weight tensor of the shape `[filterH, filterW, inChannels,
+    outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
+      k: FIR filter of the shape `[firH, firW]` or `[firN]`
+    (separable). The default is `[1] * factor`, which corresponds to average pooling.
+      factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

     Returns:
-    Tensor of the shape `[N, C, H // factor, W // factor]` or
-    `[N, H // factor, W // factor, C]`, and same datatype as `x`.
+        Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same datatype
+        as `x`.
     """

     assert isinstance(factor, int) and factor >= 1
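For intuition, the fused ops above compute the same result as the naive two-step pipelines (up to the single upfront padding pass and the TF-style `[filterH, filterW, inChannels, outChannels]` weight layout), using the `upsample_2d`/`downsample_2d` helpers defined below:

    # Illustrative equivalence only -- not the library's actual code path:
    #   upsample_conv_2d(x, w, k, factor)    ~=  conv2d(upsample_2d(x, k, factor), w)
    #   conv_downsample_2d(x, w, k, factor)  ~=  downsample_2d(conv2d(x, w), k, factor)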
@@ -258,22 +249,16 @@ def _shape(x, dim):
 def upsample_2d(x, k=None, factor=2, gain=1):
     r"""Upsample a batch of 2D images with the given filter.

-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
-    and upsamples each image with the given filter. The filter is normalized so
-    that
-    if the input pixels are constant, they will be scaled by the specified
-    `gain`.
-    Pixels outside the image are assumed to be zero, and the filter is padded
-    with
-    zeros so that its shape is a multiple of the upsampling factor. Args:
-    x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
+    filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
+    `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
+    a multiple of the upsampling factor.
+    Args:
+      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
       C]`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to
-    nearest-neighbor upsampling.
-    factor: Integer upsampling factor (default: 2).
-    gain: Scaling factor for signal magnitude (default: 1.0).
+      k: FIR filter of the shape `[firH, firW]` or `[firN]`
+    (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
+      factor: Integer upsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).

     Returns:
         Tensor of the shape `[N, C, H * factor, W * factor]`
@@ -289,22 +274,16 @@ def upsample_2d(x, k=None, factor=2, gain=1):
 def downsample_2d(x, k=None, factor=2, gain=1):
     r"""Downsample a batch of 2D images with the given filter.

-    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]`
-    and downsamples each image with the given filter. The filter is normalized
-    so that
-    if the input pixels are constant, they will be scaled by the specified
-    `gain`.
-    Pixels outside the image are assumed to be zero, and the filter is padded
-    with
-    zeros so that its shape is a multiple of the downsampling factor. Args:
-    x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
+    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
+    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
+    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
+    shape is a multiple of the downsampling factor.
+    Args:
+      x: Input tensor of the shape `[N, C, H, W]` or `[N, H, W,
       C]`.
-    k: FIR filter of the shape `[firH, firW]` or `[firN]`
-    (separable). The default is `[1] * factor`, which corresponds to
-    average pooling.
-    factor: Integer downsampling factor (default: 2).
-    gain: Scaling factor for signal magnitude (default: 1.0).
+      k: FIR filter of the shape `[firH, firW]` or `[firN]`
+    (separable). The default is `[1] * factor`, which corresponds to average pooling.
+      factor: Integer downsampling factor (default: 2). gain: Scaling factor for signal magnitude (default: 1.0).
Returns: Tensor of the shape `[N, C, H // factor, W // factor]` diff --git a/src/diffusers/pipelines/grad_tts_utils.py b/src/diffusers/pipelines/grad_tts_utils.py index 15995b85c8..f36f31c5a9 100644 --- a/src/diffusers/pipelines/grad_tts_utils.py +++ b/src/diffusers/pipelines/grad_tts_utils.py @@ -290,7 +290,7 @@ def normalize_numbers(text): return text -""" from https://github.com/keithito/tacotron """ +""" from https://github.com/keithito/tacotron""" _pad = "_" @@ -322,8 +322,8 @@ def get_arpabet(word, dictionary): def text_to_sequence(text, cleaner_names=[english_cleaners], dictionary=None): """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + The text can optionally have ARPAbet sequences enclosed in curly braces embedded in it. For example, "Turn left on + {HH AW1 S S T AH0 N} Street." Args: text: string to convert to a sequence diff --git a/src/diffusers/pipelines/pipeline_bddm.py b/src/diffusers/pipelines/pipeline_bddm.py index 8b24cb9ceb..09120fdab0 100644 --- a/src/diffusers/pipelines/pipeline_bddm.py +++ b/src/diffusers/pipelines/pipeline_bddm.py @@ -29,8 +29,7 @@ from ..pipeline_utils import DiffusionPipeline def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in): """ Embed a diffusion step $t$ into a higher dimensional space - E.g. the embedding vector in the 128-dimensional space is - [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), + E.g. the embedding vector in the 128-dimensional space is [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))] Parameters: @@ -53,8 +52,7 @@ def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in): """ -Below scripts were borrowed from -https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py +Below scripts were borrowed from https://github.com/philsyn/DiffWave-Vocoder/blob/master/WaveNet.py """ diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py index 8680b7542a..9a67790b35 100644 --- a/src/diffusers/pipelines/pipeline_glide.py +++ b/src/diffusers/pipelines/pipeline_glide.py @@ -699,9 +699,8 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): """ Extract values from a 1-D numpy array for a batch of indices. - :param arr: the 1-D numpy array. - :param timesteps: a tensor of indices into the array to extract. - :param broadcast_shape: a larger shape of K dimensions with the batch + :param arr: the 1-D numpy array. :param timesteps: a tensor of indices into the array to extract. :param + broadcast_shape: a larger shape of K dimensions with the batch dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
""" diff --git a/src/diffusers/pipelines/pipeline_grad_tts.py b/src/diffusers/pipelines/pipeline_grad_tts.py index 743104e658..93770fe21e 100644 --- a/src/diffusers/pipelines/pipeline_grad_tts.py +++ b/src/diffusers/pipelines/pipeline_grad_tts.py @@ -1,4 +1,4 @@ -""" from https://github.com/jaywalnut310/glow-tts """ +""" from https://github.com/jaywalnut310/glow-tts""" import math diff --git a/src/diffusers/pipelines/pipeline_latent_diffusion.py b/src/diffusers/pipelines/pipeline_latent_diffusion.py index ffc8ae670c..fea7a287ed 100644 --- a/src/diffusers/pipelines/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/pipeline_latent_diffusion.py @@ -554,11 +554,9 @@ class LDMBertModel(LDMBertPreTrainedModel): def get_timestep_embedding(timesteps, embedding_dim): """ - This matches the implementation in Denoising Diffusion Probabilistic Models: - From Fairseq. - Build sinusoidal embeddings. - This matches the implementation in tensor2tensor, but differs slightly - from the description in Section 3.5 of "Attention Is All You Need". + This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. Build sinusoidal + embeddings. This matches the implementation in tensor2tensor, but differs slightly from the description in Section + 3.5 of "Attention Is All You Need". """ assert len(timesteps.shape) == 1 @@ -1055,8 +1053,8 @@ class Decoder(nn.Module): class VectorQuantizer(nn.Module): """ - Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly - avoids costly matrix multiplications and allows for post-hoc remapping of indices. + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly avoids costly matrix + multiplications and allows for post-hoc remapping of indices. """ # NOTE: due to a bug the beta term was applied to the wrong term. for diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index d11af4ec25..f626cb1ca5 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. + :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t + from 0 to 1 and + produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. 
""" diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index d908850dfe..d4230ff069 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -25,13 +25,12 @@ from .scheduling_utils import SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. + :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t + from 0 to 1 and + produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. """ diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index e7479d5497..8533ad6cd7 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -24,13 +24,12 @@ from .scheduling_utils import SchedulerMixin def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. + :param num_diffusion_timesteps: the number of betas to produce. :param alpha_bar: a lambda that takes an argument t + from 0 to 1 and + produces the cumulative product of (1-beta) up to that part of the diffusion process. :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities. """ diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 99fecaa07f..f81bf5cc03 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -20,11 +20,10 @@ class EMAModel: ): """ @crowsonkb's notes on EMA Warmup: - If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are - good values for models you plan to train for a million or more steps (reaches decay - factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models - you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at - 215.4k steps). + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan + to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), + gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 + at 215.4k steps). Args: inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. power (float): Exponential factor of EMA warmup. Default: 2/3. 
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 470526a8b5..2c56ba4a8a 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -89,20 +89,20 @@ class RevisionNotFoundError(HTTPError): TRANSFORMERS_IMPORT_ERROR = """ -{0} requires the transformers library but it was not found in your environment. You can install it with pip: -`pip install transformers` +{0} requires the transformers library but it was not found in your environment. You can install it with pip: `pip +install transformers` """ UNIDECODE_IMPORT_ERROR = """ -{0} requires the unidecode library but it was not found in your environment. You can install it with pip: -`pip install Unidecode` +{0} requires the unidecode library but it was not found in your environment. You can install it with pip: `pip install +Unidecode` """ INFLECT_IMPORT_ERROR = """ -{0} requires the inflect library but it was not found in your environment. You can install it with pip: -`pip install inflect` +{0} requires the inflect library but it was not found in your environment. You can install it with pip: `pip install +inflect` """ diff --git a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py b/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py index 320a93134a..8c2aec218c 100644 --- a/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py +++ b/src/diffusers/utils/dummy_transformers_and_inflect_and_unidecode_objects.py @@ -3,7 +3,7 @@ from ..utils import DummyObject, requires_backends -class GradTTS(metaclass=DummyObject): +class GradTTSPipeline(metaclass=DummyObject): _backends = ["transformers", "inflect", "unidecode"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/dummy_transformers_objects.py b/src/diffusers/utils/dummy_transformers_objects.py index 1efb17297f..ac34367a3b 100644 --- a/src/diffusers/utils/dummy_transformers_objects.py +++ b/src/diffusers/utils/dummy_transformers_objects.py @@ -31,14 +31,14 @@ class UNetGradTTSModel(metaclass=DummyObject): requires_backends(self, ["transformers"]) -class Glide(metaclass=DummyObject): +class GlidePipeline(metaclass=DummyObject): _backends = ["transformers"] def __init__(self, *args, **kwargs): requires_backends(self, ["transformers"]) -class LatentDiffusion(metaclass=DummyObject): +class LatentDiffusionPipeline(metaclass=DummyObject): _backends = ["transformers"] def __init__(self, *args, **kwargs): diff --git a/src/diffusers/utils/logging.py b/src/diffusers/utils/logging.py index 63027f3267..1f2d0227b8 100644 --- a/src/diffusers/utils/logging.py +++ b/src/diffusers/utils/logging.py @@ -233,8 +233,8 @@ def disable_propagation() -> None: def enable_propagation() -> None: """ - Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to - prevent double logging if the root logger has been configured. + Enable propagation of the library log outputs. Please disable the HuggingFace Diffusers' default handler to prevent + double logging if the root logger has been configured. 
""" _configure_library_root_logger() diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index aa40513621..453b4fa285 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -22,7 +22,6 @@ import numpy as np import torch from diffusers import ( - GradTTSPipeline, BDDMPipeline, DDIMPipeline, DDIMScheduler, @@ -31,6 +30,7 @@ from diffusers import ( GlidePipeline, GlideSuperResUNetModel, GlideTextToImageUNetModel, + GradTTSPipeline, GradTTSScheduler, LatentDiffusionPipeline, NCSNpp, diff --git a/utils/check_copies.py b/utils/check_copies.py index 7565bfa51b..50f02cac65 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -24,7 +24,7 @@ from doc_builder.style_doc import style_docstrings_in_code # All paths are set with the intent you should run this script from the root of the repo with the command # python utils/check_copies.py -TRANSFORMERS_PATH = "src/transformers" +TRANSFORMERS_PATH = "src/diffusers" PATH_TO_DOCS = "docs/source/en" REPO_PATH = "." @@ -76,7 +76,7 @@ def _should_continue(line, indent): return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None -def find_code_in_transformers(object_name): +def find_code_in_diffusers(object_name): """Find and return the code source code of `object_name`.""" parts = object_name.split(".") i = 0 @@ -88,9 +88,7 @@ def find_code_in_transformers(object_name): if i < len(parts): module = os.path.join(module, parts[i]) if i >= len(parts): - raise ValueError( - f"`object_name` should begin with the name of a module of transformers but got {object_name}." - ) + raise ValueError(f"`object_name` should begin with the name of a module of diffusers but got {object_name}.") with open(os.path.join(TRANSFORMERS_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() @@ -121,7 +119,7 @@ def find_code_in_transformers(object_name): return "".join(code_lines) -_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)") +_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+diffusers\.(\S+\.\S+)\s*($|\S.*$)") _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") @@ -167,7 +165,7 @@ def is_copy_consistent(filename, overwrite=False): # There is some copied code here, let's retrieve the original. indent, object_name, replace_pattern = search.groups() - theoretical_code = find_code_in_transformers(object_name) + theoretical_code = find_code_in_diffusers(object_name) theoretical_indent = get_indent(theoretical_code) start_index = line_index + 1 if indent == theoretical_indent else line_index + 2 @@ -235,7 +233,9 @@ def check_copies(overwrite: bool = False): + diff + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." 
 )
-    check_model_list_copy(overwrite=overwrite)
+
+
+# check_model_list_copy(overwrite=overwrite)


 def check_full_copies(overwrite: bool = False):
@@ -348,8 +348,8 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):


 def convert_readme_to_index(model_list):
-    model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
-    return model_list.replace("https://huggingface.co/docs/transformers/", "")
+    model_list = model_list.replace("https://huggingface.co/docs/diffusers/main/", "")
+    return model_list.replace("https://huggingface.co/docs/diffusers/", "")


 def _find_text_in_file(filename, start_prompt, end_prompt):
@@ -383,9 +383,9 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
     # Fix potential doc links in the README
     with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
         readme = f.read()
-    new_readme = readme.replace("https://huggingface.co/transformers", "https://huggingface.co/docs/transformers")
+    new_readme = readme.replace("https://huggingface.co/diffusers", "https://huggingface.co/docs/diffusers")
     new_readme = new_readme.replace(
-        "https://huggingface.co/docs/main/transformers", "https://huggingface.co/docs/transformers/main"
+        "https://huggingface.co/docs/main/diffusers", "https://huggingface.co/docs/diffusers/main"
     )
     if new_readme != readme:
         if overwrite:
diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py
new file mode 100644
index 0000000000..6501654872
--- /dev/null
+++ b/utils/custom_init_isort.py
@@ -0,0 +1,250 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import re
+
+
+PATH_TO_TRANSFORMERS = "src/diffusers"
+
+# Pattern that looks at the indentation in a line.
+_re_indent = re.compile(r"^(\s*)\S")
+# Pattern that matches `"key":` and puts `key` in group 0.
+_re_direct_key = re.compile(r'^\s*"([^"]+)":')
+# Pattern that matches `_import_structure["key"]` and puts `key` in group 0.
+_re_indirect_key = re.compile(r'^\s*_import_structure\["([^"]+)"\]')
+# Pattern that matches `"key",` and puts `key` in group 0.
+_re_strip_line = re.compile(r'^\s*"([^"]+)",\s*$')
+# Pattern that matches any `[stuff]` and puts `stuff` in group 0.
+_re_bracket_content = re.compile(r"\[([^\]]+)\]")
+
+
+def get_indent(line):
+    """Returns the indent in `line`."""
+    search = _re_indent.search(line)
+    return "" if search is None else search.groups()[0]
+
+
+def split_code_in_indented_blocks(code, indent_level="", start_prompt=None, end_prompt=None):
+    """
+    Split `code` into its indented blocks, starting at `indent_level`. If provided, begins splitting after
+    `start_prompt` and stops at `end_prompt` (but returns what's before `start_prompt` as a first block and what's
+    after `end_prompt` as a last block, so `code` is always the same as joining the result of this function).
+    """
+    # Let's split the code into lines and move to start_index.
+    index = 0
+    lines = code.split("\n")
+    if start_prompt is not None:
+        while not lines[index].startswith(start_prompt):
+            index += 1
+        blocks = ["\n".join(lines[:index])]
+    else:
+        blocks = []
+
+    # We split into blocks until we get to the `end_prompt` (or the end of the block).
+    current_block = [lines[index]]
+    index += 1
+    while index < len(lines) and (end_prompt is None or not lines[index].startswith(end_prompt)):
+        if len(lines[index]) > 0 and get_indent(lines[index]) == indent_level:
+            if len(current_block) > 0 and get_indent(current_block[-1]).startswith(indent_level + " "):
+                current_block.append(lines[index])
+                blocks.append("\n".join(current_block))
+                if index < len(lines) - 1:
+                    current_block = [lines[index + 1]]
+                    index += 1
+                else:
+                    current_block = []
+            else:
+                blocks.append("\n".join(current_block))
+                current_block = [lines[index]]
+        else:
+            current_block.append(lines[index])
+        index += 1
+
+    # Adds current block if it's nonempty.
+    if len(current_block) > 0:
+        blocks.append("\n".join(current_block))
+
+    # Add final block after end_prompt if provided.
+    if end_prompt is not None and index < len(lines):
+        blocks.append("\n".join(lines[index:]))
+
+    return blocks
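+
+# For example (illustrative, not exhaustive), the `_re_*` patterns defined at the
+# top of this file behave as follows on typical `__init__.py` lines:
+#   _re_direct_key      matches '    "models": ['                                and captures "models"
+#   _re_indirect_key    matches '    _import_structure["schedulers"].append(..)' and captures "schedulers"
+#   _re_strip_line      matches '        "DDPMScheduler",'                       and captures "DDPMScheduler"
+#   _re_bracket_content matches '["b", "a"]'                                     and captures '"b", "a"'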
+
+
+def ignore_underscore(key):
+    "Wraps a `key` (which maps an object to a string) so it lower-cases the result and removes underscores."
+
+    def _inner(x):
+        return key(x).lower().replace("_", "")
+
+    return _inner
+
+
+def sort_objects(objects, key=None):
+    "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
+    # If no key is provided, we use a noop.
+    def noop(x):
+        return x
+
+    if key is None:
+        key = noop
+    # Constants are all uppercase, they go first.
+    constants = [obj for obj in objects if key(obj).isupper()]
+    # Classes are not all uppercase but start with a capital, they go second.
+    classes = [obj for obj in objects if key(obj)[0].isupper() and not key(obj).isupper()]
+    # Functions begin with a lowercase, they go last.
+    functions = [obj for obj in objects if not key(obj)[0].isupper()]
+
+    key1 = ignore_underscore(key)
+    return sorted(constants, key=key1) + sorted(classes, key=key1) + sorted(functions, key=key1)
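+
+# For example (illustrative): constants first, then classes, then functions,
+# each group sorted case-insensitively with underscores ignored:
+#   sort_objects(["UNetLDMModel", "get_timestep_embedding", "SCHEDULERS", "EMAModel"])
+#   -> ["SCHEDULERS", "EMAModel", "UNetLDMModel", "get_timestep_embedding"]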
+
+
+def sort_objects_in_import(import_statement):
+    """
+    Return the same `import_statement` but with objects properly sorted.
+    """
+    # This inner function sorts imports between [ ].
+    def _replace(match):
+        imports = match.groups()[0]
+        if "," not in imports:
+            return f"[{imports}]"
+        keys = [part.strip().replace('"', "") for part in imports.split(",")]
+        # We will have a final empty element if the line finished with a comma.
+        if len(keys[-1]) == 0:
+            keys = keys[:-1]
+        return "[" + ", ".join([f'"{k}"' for k in sort_objects(keys)]) + "]"
+
+    lines = import_statement.split("\n")
+    if len(lines) > 3:
+        # Here we have to sort internal imports that are on several lines (one per name):
+        #     key: [
+        #         "object1",
+        #         "object2",
+        #         ...
+        #     ]
+
+        # We may have to ignore one or two lines on each side.
+        idx = 2 if lines[1].strip() == "[" else 1
+        keys_to_sort = [(i, _re_strip_line.search(line).groups()[0]) for i, line in enumerate(lines[idx:-idx])]
+        sorted_indices = sort_objects(keys_to_sort, key=lambda x: x[1])
+        sorted_lines = [lines[x[0] + idx] for x in sorted_indices]
+        return "\n".join(lines[:idx] + sorted_lines + lines[-idx:])
+    elif len(lines) == 3:
+        # Here we have to sort internal imports that are on one separate line:
+        #     key: [
+        #         "object1", "object2", ...
+        #     ]
+        if _re_bracket_content.search(lines[1]) is not None:
+            lines[1] = _re_bracket_content.sub(_replace, lines[1])
+        else:
+            keys = [part.strip().replace('"', "") for part in lines[1].split(",")]
+            # We will have a final empty element if the line finished with a comma.
+            if len(keys[-1]) == 0:
+                keys = keys[:-1]
+            lines[1] = get_indent(lines[1]) + ", ".join([f'"{k}"' for k in sort_objects(keys)])
+        return "\n".join(lines)
+    else:
+        # Finally we have to deal with imports fitting on one line.
+        import_statement = _re_bracket_content.sub(_replace, import_statement)
+        return import_statement
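+
+# For example (illustrative), a one-line statement gets its bracketed lists sorted in place:
+#   sort_objects_in_import('_import_structure["models"].extend(["UNetLDMModel", "EMAModel", "SCHEDULERS"])')
+#   -> '_import_structure["models"].extend(["SCHEDULERS", "EMAModel", "UNetLDMModel"])'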
+
+
+def sort_imports(file, check_only=True):
+    """
+    Sort `_import_structure` imports in `file`; `check_only` determines if we only check or overwrite.
+    """
+    with open(file, "r") as f:
+        code = f.read()
+
+    if "_import_structure" not in code:
+        return
+
+    # Blocks of indent level 0
+    main_blocks = split_code_in_indented_blocks(
+        code, start_prompt="_import_structure = {", end_prompt="if TYPE_CHECKING:"
+    )
+
+    # We ignore block 0 (everything until start_prompt) and the last block (everything after end_prompt).
+    for block_idx in range(1, len(main_blocks) - 1):
+        # Check if the block contains some `_import_structure` entries to sort.
+        block = main_blocks[block_idx]
+        block_lines = block.split("\n")
+
+        # Get to the start of the imports.
+        line_idx = 0
+        while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
+            # Skip dummy import blocks
+            if "import dummy" in block_lines[line_idx]:
+                line_idx = len(block_lines)
+            else:
+                line_idx += 1
+        if line_idx >= len(block_lines):
+            continue
+
+        # Ignore beginning and last line: they don't contain anything.
+        internal_block_code = "\n".join(block_lines[line_idx:-1])
+        indent = get_indent(block_lines[1])
+        # Split the internal block into blocks of indent level 1.
+        internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
+        # We have two categories of import key: list or `_import_structure[key].append/extend`.
+        pattern = _re_direct_key if "_import_structure" in block_lines[0] else _re_indirect_key
+        # Grab the keys, but there is a trap: some lines are empty or just comments.
+        keys = [(pattern.search(b).groups()[0] if pattern.search(b) is not None else None) for b in internal_blocks]
+        # We only sort the lines with a key.
+        keys_to_sort = [(i, key) for i, key in enumerate(keys) if key is not None]
+        sorted_indices = [x[0] for x in sorted(keys_to_sort, key=lambda x: x[1])]
+
+        # We reorder the blocks by leaving empty lines/comments as they were and reorder the rest.
+        count = 0
+        reordered_blocks = []
+        for i in range(len(internal_blocks)):
+            if keys[i] is None:
+                reordered_blocks.append(internal_blocks[i])
+            else:
+                block = sort_objects_in_import(internal_blocks[sorted_indices[count]])
+                reordered_blocks.append(block)
+                count += 1
+
+        # And we put our main block back together with its first and last line.
+        main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reordered_blocks + [block_lines[-1]])
+
+    if code != "\n".join(main_blocks):
+        if check_only:
+            return True
+        else:
+            print(f"Overwriting {file}.")
+            with open(file, "w") as f:
+                f.write("\n".join(main_blocks))
+
+
+def sort_imports_in_all_inits(check_only=True):
+    failures = []
+    for root, _, files in os.walk(PATH_TO_TRANSFORMERS):
+        if "__init__.py" in files:
+            result = sort_imports(os.path.join(root, "__init__.py"), check_only=check_only)
+            if result:
+                # Record every offending __init__.py, appending rather than overwriting the list.
+                failures.append(os.path.join(root, "__init__.py"))
+    if len(failures) > 0:
+        raise ValueError(f"Would overwrite {len(failures)} files, run `make style`.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--check_only", action="store_true", help="Whether to only check or fix style.")
+    args = parser.parse_args()
+
+    sort_imports_in_all_inits(check_only=args.check_only)
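Assuming the script lands at `utils/custom_init_isort.py` as in this patch, the expected invocations (inferred from the argparse setup above) are:

    python utils/custom_init_isort.py --check_only   # raise ValueError if any __init__.py needs re-sorting
    python utils/custom_init_isort.py                # rewrite the offending __init__.py files in place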