From 37d113cce729007558cbb95ebb081d39fa6ebcff Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 17 Jan 2023 23:09:29 +0100 Subject: [PATCH] DiT Pipeline (#1806) * added dit model * import * initial pipeline * initial convert script * initial pipeline * make style * raise valueerror * single function * rename classes * use DDIMScheduler * timesteps embedder * samples to cpu * fix var names * fix numpy type * use timesteps class for proj * fix typo * fix arg name * flip_sin_to_cos and better var names * fix C shape cal * make style * remove unused imports * cleanup * add back patch_size * initial dit doc * typo * Update docs/source/api/pipelines/dit.mdx Co-authored-by: Suraj Patil * added copyright license headers * added example usage and toc * fix variable names asserts * remove comment * added docs * fix typo * upstream changes * set proper device for drop_ids * added initial dit pipeline test * update docs * fix imports * make fix-copies * isort * fix imports * get rid of more magic numbers * fix code when guidance is off * remove block_kwargs * cleanup script * removed to_2tuple * use FeedForward class instead of another MLP * style * work on mergint DiTBlock with BasicTransformerBlock * added missing final_dropout and args to BasicTransformerBlock * use norm from block * fix arg * remove unused arg * fix call to class_embedder * use timesteps * make style * attn_output gets multiplied * removed commented code * use Transformer2D * use self.is_input_patches * fix flags * fixed conversion to use Transformer2DModel * fixes for pipeline * remove dit.py * fix timesteps device * use randn_tensor and fix fp16 inf. * timesteps_emb already the right dtype * fix dit test class * fix test and style * fix norm2 usage in vq-diffusion * added author names to pipeline and lmagenet labels link * fix tests * use norm_type as string * rename dit to transformer * fix name * fix test * set norm_type = "layer" by default * fix tests * do not skip common tests * Update src/diffusers/models/attention.py Co-authored-by: Suraj Patil * revert AdaLayerNorm API * fix norm_type name * make sure all components are in eval mode * revert norm2 API * compact * finish deprecation * add slow tests * remove @ * refactor some stuff * upload * Update src/diffusers/pipelines/dit/pipeline_dit.py * finish more * finish docs * improve docs * finish docs Co-authored-by: Suraj Patil Co-authored-by: William Berman Co-authored-by: Patrick von Platen --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/dit.mdx | 59 ++++++ docs/source/en/api/schedulers/overview.mdx | 3 + scripts/convert_dit_to_diffusers.py | 162 ++++++++++++++ src/diffusers/__init__.py | 1 + src/diffusers/models/attention.py | 102 +++++++-- src/diffusers/models/embeddings.py | 145 +++++++++++++ src/diffusers/models/transformer_2d.py | 81 ++++++- src/diffusers/pipelines/__init__.py | 1 + .../alt_diffusion/pipeline_alt_diffusion.py | 18 +- .../pipeline_alt_diffusion_img2img.py | 18 +- src/diffusers/pipelines/dit/__init__.py | 1 + src/diffusers/pipelines/dit/pipeline_dit.py | 199 ++++++++++++++++++ .../pipeline_stable_diffusion.py | 18 +- .../pipeline_stable_diffusion_depth2img.py | 18 +- ...peline_stable_diffusion_image_variation.py | 18 +- .../pipeline_stable_diffusion_img2img.py | 18 +- .../pipeline_stable_diffusion_inpaint.py | 4 +- ...ipeline_stable_diffusion_inpaint_legacy.py | 18 +- .../pipeline_stable_diffusion_upscale.py | 4 +- .../pipeline_stable_diffusion_safe.py | 18 +- .../pipeline_versatile_diffusion.py | 6 +- ...ipeline_versatile_diffusion_dual_guided.py | 6 +- ...ine_versatile_diffusion_image_variation.py | 6 +- ...eline_versatile_diffusion_text_to_image.py | 6 +- src/diffusers/schedulers/__init__.py | 9 +- src/diffusers/schedulers/scheduling_ddim.py | 6 +- .../schedulers/scheduling_ddim_flax.py | 4 +- src/diffusers/schedulers/scheduling_ddpm.py | 6 +- .../schedulers/scheduling_ddpm_flax.py | 4 +- .../schedulers/scheduling_deis_multistep.py | 5 +- .../scheduling_dpmsolver_multistep.py | 6 +- .../scheduling_dpmsolver_multistep_flax.py | 4 +- .../scheduling_dpmsolver_singlestep.py | 5 +- .../scheduling_euler_ancestral_discrete.py | 6 +- .../schedulers/scheduling_euler_discrete.py | 6 +- .../schedulers/scheduling_heun_discrete.py | 5 +- .../scheduling_k_dpm_2_ancestral_discrete.py | 6 +- .../schedulers/scheduling_k_dpm_2_discrete.py | 5 +- .../schedulers/scheduling_lms_discrete.py | 6 +- .../scheduling_lms_discrete_flax.py | 4 +- src/diffusers/schedulers/scheduling_pndm.py | 5 +- .../schedulers/scheduling_pndm_flax.py | 4 +- src/diffusers/schedulers/scheduling_utils.py | 16 ++ .../schedulers/scheduling_utils_flax.py | 12 +- src/diffusers/utils/__init__.py | 1 - src/diffusers/utils/constants.py | 15 -- src/diffusers/utils/dummy_pt_objects.py | 15 ++ tests/pipelines/dit/__init__.py | 0 tests/pipelines/dit/test_dit.py | 135 ++++++++++++ tests/test_pipelines_common.py | 8 +- 51 files changed, 995 insertions(+), 235 deletions(-) create mode 100644 docs/source/en/api/pipelines/dit.mdx create mode 100644 scripts/convert_dit_to_diffusers.py create mode 100644 src/diffusers/pipelines/dit/__init__.py create mode 100644 src/diffusers/pipelines/dit/pipeline_dit.py create mode 100644 tests/pipelines/dit/__init__.py create mode 100644 tests/pipelines/dit/test_dit.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1be09fdda0..2c0d94fcc1 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -106,6 +106,8 @@ title: DDIM - local: api/pipelines/ddpm title: DDPM + - local: api/pipelines/dit + title: DiT - local: api/pipelines/latent_diffusion title: Latent Diffusion - local: api/pipelines/paint_by_example diff --git a/docs/source/en/api/pipelines/dit.mdx b/docs/source/en/api/pipelines/dit.mdx new file mode 100644 index 0000000000..d7ab18e2ea --- /dev/null +++ b/docs/source/en/api/pipelines/dit.mdx @@ -0,0 +1,59 @@ + + +# [Scalable Diffusion Models with Transformers](https://www.wpeebles.com/DiT) (DiT) + +## Overview + +[Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (DiT) by William Peebles and Saining Xie. + +The abstract of the paper is the following: + +*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.* + +The original codebase of this paper can be found here: [facebookresearch/dit](https://github.com/facebookresearch/dit). + +## Available Pipelines: + +| Pipeline | Tasks | Colab +|---|---|:---:| +| [pipeline_dit.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py) | *Conditional Image Generation* | - | + + +## Usage example + +```python +from diffusers import DiTPipeline, DPMSolverMultistepScheduler +import torch + +pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16) +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe = pipe.to("cuda") + +# pick words from Imagenet class labels +pipe.labels # to print all available words + +# pick words that exist in ImageNet +words = ["white shark", "umbrella"] + +class_ids = pipe.get_label_ids(words) + +generator = torch.manual_seed(33) +output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator) + +image = output.images[0] # label 'white shark' +``` + +## DiTPipeline +[[autodoc]] DiTPipeline + - all + - __call__ diff --git a/docs/source/en/api/schedulers/overview.mdx b/docs/source/en/api/schedulers/overview.mdx index 7e139d152b..d27fbe10c5 100644 --- a/docs/source/en/api/schedulers/overview.mdx +++ b/docs/source/en/api/schedulers/overview.mdx @@ -37,6 +37,7 @@ To this end, the design of schedulers is such that: - Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality. - Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial Jax support currently exists). +- Many diffusion pipelines, such as [`StableDiffusionPipeline`] and [`DiTPipeline`] can use any of [`KarrasDiffusionSchedulers`] ## Schedulers Summary @@ -80,4 +81,6 @@ The class [`SchedulerOutput`] contains the outputs from any schedulers `step(... [[autodoc]] schedulers.scheduling_utils.SchedulerOutput +### KarrasDiffusionSchedulers +[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers diff --git a/scripts/convert_dit_to_diffusers.py b/scripts/convert_dit_to_diffusers.py new file mode 100644 index 0000000000..e14b4ad2a7 --- /dev/null +++ b/scripts/convert_dit_to_diffusers.py @@ -0,0 +1,162 @@ +import argparse +import os + +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel +from torchvision.datasets.utils import download_url + + +pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"} + + +def download_model(model_name): + """ + Downloads a pre-trained DiT model from the web. + """ + local_path = f"pretrained_models/{model_name}" + if not os.path.isfile(local_path): + os.makedirs("pretrained_models", exist_ok=True) + web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}" + download_url(web_path, "pretrained_models") + model = torch.load(local_path, map_location=lambda storage, loc: storage) + return model + + +def main(args): + state_dict = download_model(pretrained_models[args.image_size]) + + state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] + state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] + state_dict.pop("x_embedder.proj.weight") + state_dict.pop("x_embedder.proj.bias") + + for depth in range(28): + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[ + "t_embedder.mlp.0.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[ + "t_embedder.mlp.0.bias" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[ + "t_embedder.mlp.2.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[ + "t_embedder.mlp.2.bias" + ] + state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[ + "y_embedder.embedding_table.weight" + ] + + state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[ + f"blocks.{depth}.adaLN_modulation.1.weight" + ] + state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[ + f"blocks.{depth}.adaLN_modulation.1.bias" + ] + + q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0) + + state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q + state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias + state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k + state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias + state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v + state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias + + state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[ + f"blocks.{depth}.attn.proj.weight" + ] + state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"] + + state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"] + state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"] + state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"] + state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"] + + state_dict.pop(f"blocks.{depth}.attn.qkv.weight") + state_dict.pop(f"blocks.{depth}.attn.qkv.bias") + state_dict.pop(f"blocks.{depth}.attn.proj.weight") + state_dict.pop(f"blocks.{depth}.attn.proj.bias") + state_dict.pop(f"blocks.{depth}.mlp.fc1.weight") + state_dict.pop(f"blocks.{depth}.mlp.fc1.bias") + state_dict.pop(f"blocks.{depth}.mlp.fc2.weight") + state_dict.pop(f"blocks.{depth}.mlp.fc2.bias") + state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight") + state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias") + + state_dict.pop("t_embedder.mlp.0.weight") + state_dict.pop("t_embedder.mlp.0.bias") + state_dict.pop("t_embedder.mlp.2.weight") + state_dict.pop("t_embedder.mlp.2.bias") + state_dict.pop("y_embedder.embedding_table.weight") + + state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"] + state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"] + state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"] + state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"] + + state_dict.pop("final_layer.linear.weight") + state_dict.pop("final_layer.linear.bias") + state_dict.pop("final_layer.adaLN_modulation.1.weight") + state_dict.pop("final_layer.adaLN_modulation.1.bias") + + # DiT XL/2 + transformer = Transformer2DModel( + sample_size=args.image_size // 8, + num_layers=28, + attention_head_dim=72, + in_channels=4, + out_channels=8, + patch_size=2, + attention_bias=True, + num_attention_heads=16, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_zero", + norm_elementwise_affine=False, + ) + transformer.load_state_dict(state_dict, strict=True) + + scheduler = DDIMScheduler( + num_train_timesteps=1000, + beta_schedule="linear", + prediction_type="epsilon", + clip_sample=False, + ) + + vae = AutoencoderKL.from_pretrained(args.vae_model) + + pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler) + + if args.save: + pipeline.save_pretrained(args.checkpoint_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--image_size", + default=256, + type=int, + required=False, + help="Image size of pretrained model, either 256 or 512.", + ) + parser.add_argument( + "--vae_model", + default="stabilityai/sd-vae-ft-ema", + type=str, + required=False, + help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.", + ) + parser.add_argument( + "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline." + ) + + args = parser.parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 90af386f31..4ee671c5aa 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -57,6 +57,7 @@ else: DDIMPipeline, DDPMPipeline, DiffusionPipeline, + DiTPipeline, ImagePipelineOutput, KarrasVePipeline, LDMPipeline, diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 85dcc800fd..08263875d0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -20,6 +20,7 @@ from torch import nn from ..utils.import_utils import is_xformers_available from .cross_attention import CrossAttention +from .embeddings import CombinedTimestepLabelEmbeddings if is_xformers_available(): @@ -196,10 +197,21 @@ class BasicTransformerBlock(nn.Module): attention_bias: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", + final_dropout: bool = False, ): super().__init__() self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm = num_embeds_ada_norm is not None + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) # 1. Self-Attn self.attn1 = CrossAttention( @@ -212,7 +224,7 @@ class BasicTransformerBlock(nn.Module): upcast_attention=upcast_attention, ) - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) # 2. Cross-Attn if cross_attention_dim is not None: @@ -228,15 +240,27 @@ class BasicTransformerBlock(nn.Module): else: self.attn2 = None - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) if cross_attention_dim is not None: - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = ( + AdaLayerNorm(dim, num_embeds_ada_norm) + if self.use_ada_layer_norm + else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) + ) else: self.norm2 = None # 3. Feed-forward - self.norm3 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine) def forward( self, @@ -245,11 +269,18 @@ class BasicTransformerBlock(nn.Module): timestep=None, attention_mask=None, cross_attention_kwargs=None, + class_labels=None, ): + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + else: + norm_hidden_states = self.norm1(hidden_states) + # 1. Self-Attention - norm_hidden_states = ( - self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) - ) cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} attn_output = self.attn1( norm_hidden_states, @@ -257,13 +288,16 @@ class BasicTransformerBlock(nn.Module): attention_mask=attention_mask, **cross_attention_kwargs, ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output hidden_states = attn_output + hidden_states if self.attn2 is not None: - # 2. Cross-Attention norm_hidden_states = ( self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) ) + + # 2. Cross-Attention attn_output = self.attn2( norm_hidden_states, encoder_hidden_states=encoder_hidden_states, @@ -273,7 +307,17 @@ class BasicTransformerBlock(nn.Module): hidden_states = attn_output + hidden_states # 3. Feed-forward - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + ff_output = self.ff(norm_hidden_states) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + + hidden_states = ff_output + hidden_states return hidden_states @@ -288,6 +332,7 @@ class FeedForward(nn.Module): mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. """ def __init__( @@ -297,6 +342,7 @@ class FeedForward(nn.Module): mult: int = 4, dropout: float = 0.0, activation_fn: str = "geglu", + final_dropout: bool = False, ): super().__init__() inner_dim = int(dim * mult) @@ -304,6 +350,8 @@ class FeedForward(nn.Module): if activation_fn == "gelu": act_fn = GELU(dim, inner_dim) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh") elif activation_fn == "geglu": act_fn = GEGLU(dim, inner_dim) elif activation_fn == "geglu-approximate": @@ -316,6 +364,9 @@ class FeedForward(nn.Module): self.net.append(nn.Dropout(dropout)) # project out self.net.append(nn.Linear(inner_dim, dim_out)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) def forward(self, hidden_states): for module in self.net: @@ -325,18 +376,19 @@ class FeedForward(nn.Module): class GELU(nn.Module): r""" - GELU activation function + GELU activation function with tanh approximation support with `approximate="tanh"`. """ - def __init__(self, dim_in: int, dim_out: int): + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): super().__init__() self.proj = nn.Linear(dim_in, dim_out) + self.approximate = approximate def gelu(self, gate): if gate.device.type != "mps": - return F.gelu(gate) + return F.gelu(gate, approximate=self.approximate) # mps: gelu is not implemented for float16 - return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) def forward(self, hidden_states): hidden_states = self.proj(hidden_states) @@ -344,7 +396,6 @@ class GELU(nn.Module): return hidden_states -# feedforward class GEGLU(nn.Module): r""" A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. @@ -402,3 +453,24 @@ class AdaLayerNorm(nn.Module): scale, shift = torch.chunk(emb, 2) x = self.norm(x) * (1 + scale) + shift return x + + +class AdaLayerNormZero(nn.Module): + """ + Norm layer adaptive layer norm zero (adaLN-Zero). + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + + self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) + + def forward(self, x, timestep, class_labels, hidden_dtype=None): + emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] + return x, gate_msa, shift_mlp, scale_mlp, gate_mlp diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0221d891f1..fc6cae43c1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -61,6 +61,96 @@ def get_timestep_embedding( return emb +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=16, + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + ): + super().__init__() + + num_patches = (height // patch_size) * (width // patch_size) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + return latent + self.pos_embed + + class TimestepEmbedding(nn.Module): def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() @@ -198,3 +288,58 @@ class ImagePositionalEmbeddings(nn.Module): emb = emb + pos_emb[:, : emb.shape[1], :] return emb + + +class LabelEmbedding(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + + Args: + num_classes (`int`): The number of classes. + hidden_size (`int`): The size of the vector embeddings. + dropout_prob (`float`): The probability of dropping a label. + """ + + def __init__(self, num_classes, hidden_size, dropout_prob): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob + else: + drop_ids = torch.tensor(force_drop_ids == 1) + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (self.training and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CombinedTimestepLabelEmbeddings(nn.Module): + def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): + super().__init__() + + self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) + self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) + + def forward(self, timestep, class_labels, hidden_dtype=None): + timesteps_proj = self.time_proj(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + class_labels = self.class_embedder(class_labels) # (N, D) + + conditioning = timesteps_emb + class_labels # (N, D) + + return conditioning diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index def4486932..57dd424aa4 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -20,8 +20,9 @@ from torch import nn from ..configuration_utils import ConfigMixin, register_to_config from ..models.embeddings import ImagePositionalEmbeddings -from ..utils import BaseOutput +from ..utils import BaseOutput, deprecate from .attention import BasicTransformerBlock +from .embeddings import PatchEmbed from .modeling_utils import ModelMixin @@ -81,6 +82,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): num_attention_heads: int = 16, attention_head_dim: int = 88, in_channels: Optional[int] = None, + out_channels: Optional[int] = None, num_layers: int = 1, dropout: float = 0.0, norm_num_groups: int = 32, @@ -88,11 +90,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin): attention_bias: bool = False, sample_size: Optional[int] = None, num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, activation_fn: str = "geglu", num_embeds_ada_norm: Optional[int] = None, use_linear_projection: bool = False, only_cross_attention: bool = False, upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, ): super().__init__() self.use_linear_projection = use_linear_projection @@ -102,18 +107,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` # Define whether input is continuous or discrete depending on configuration - self.is_input_continuous = in_channels is not None + self.is_input_continuous = (in_channels is not None) and (patch_size is None) self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" if self.is_input_continuous and self.is_input_vectorized: raise ValueError( f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" " sure that either `in_channels` or `num_vector_embeds` is None." ) - elif not self.is_input_continuous and not self.is_input_vectorized: + elif self.is_input_vectorized and self.is_input_patches: raise ValueError( - f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" - " sure that either `in_channels` or `num_vector_embeds` is not None." + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." ) # 2. Define input layers @@ -137,6 +159,20 @@ class Transformer2DModel(ModelMixin, ConfigMixin): self.latent_image_embedding = ImagePositionalEmbeddings( num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) # 3. Define transformers blocks self.transformer_blocks = nn.ModuleList( @@ -152,13 +188,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin): attention_bias=attention_bias, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, ) for d in range(num_layers) ] ) # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels if self.is_input_continuous: + # TODO: should use out_channels for continous projections if use_linear_projection: self.proj_out = nn.Linear(in_channels, inner_dim) else: @@ -166,12 +206,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin): elif self.is_input_vectorized: self.norm_out = nn.LayerNorm(inner_dim) self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) def forward( self, hidden_states, encoder_hidden_states=None, timestep=None, + class_labels=None, cross_attention_kwargs=None, return_dict: bool = True, ): @@ -185,6 +230,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin): self-attention. timestep ( `torch.long`, *optional*): Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels + conditioning. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. @@ -195,7 +243,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): """ # 1. Input if self.is_input_continuous: - batch, channel, height, width = hidden_states.shape + batch, _, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) @@ -209,6 +257,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin): hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) # 2. Blocks for block in self.transformer_blocks: @@ -217,6 +267,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin): encoder_hidden_states=encoder_hidden_states, timestep=timestep, cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, ) # 3. Output @@ -237,6 +288,24 @@ class Transformer2DModel(ModelMixin, ConfigMixin): # log(p(x_0)) output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) if not return_dict: return (output,) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5b461ba879..f0a6db7123 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -18,6 +18,7 @@ else: from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline + from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion_uncond import LDMPipeline from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py index 5166cbb294..6978ab8e28 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py @@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -91,14 +84,7 @@ class AltDiffusionPipeline(DiffusionPipeline): text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index 3bd7b3e75b..67c1d693ef 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -25,14 +25,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -129,14 +122,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline): text_encoder: RobertaSeriesModelWithTransformation, tokenizer: XLMRobertaTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/dit/__init__.py b/src/diffusers/pipelines/dit/__init__.py new file mode 100644 index 0000000000..4ef0729cb4 --- /dev/null +++ b/src/diffusers/pipelines/dit/__init__.py @@ -0,0 +1 @@ +from .pipeline_dit import DiTPipeline diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py new file mode 100644 index 0000000000..ea372036f9 --- /dev/null +++ b/src/diffusers/pipelines/dit/pipeline_dit.py @@ -0,0 +1,199 @@ +# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) +# William Peebles and Saining Xie +# +# Copyright (c) 2021 OpenAI +# MIT License +# +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple, Union + +import torch + +from ...models import AutoencoderKL, Transformer2DModel +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +class DiTPipeline(DiffusionPipeline): + r""" + This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Parameters: + transformer ([`Transformer2DModel`]): + Class conditioned Transformer in Diffusion model to denoise the encoded image latents. + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + scheduler ([`DDIMScheduler`]): + A scheduler to be used in combination with `dit` to denoise the encoded image latents. + """ + + def __init__( + self, + transformer: Transformer2DModel, + vae: AutoencoderKL, + scheduler: KarrasDiffusionSchedulers, + id2label: Optional[Dict[int, str]] = None, + ): + super().__init__() + self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler) + + # create a imagenet -> id dictionary for easier use + self.labels = {} + if id2label is not None: + for key, value in id2label.items(): + for label in value.split(","): + self.labels[label.lstrip().rstrip()] = int(key) + self.labels = dict(sorted(self.labels.items())) + + def get_label_ids(self, label: Union[str, List[str]]) -> List[int]: + r""" + + Map label strings, *e.g.* from ImageNet, to corresponding class ids. + + Parameters: + label (`str` or `dict` of `str`): label strings to be mapped to class ids. + + Returns: + `list` of `int`: Class ids to be processed by pipeline. + """ + + if not isinstance(label, list): + label = list(label) + + for l in label: + if l not in self.labels: + raise ValueError( + f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}." + ) + + return [self.labels[l] for l in label] + + @torch.no_grad() + def __call__( + self, + class_labels: List[int], + guidance_scale: float = 4.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + num_inference_steps: int = 50, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + r""" + Function invoked when calling the pipeline for generation. + + Args: + class_labels (List[int]): + List of imagenet class labels for the images to be generated. + guidance_scale (`float`, *optional*, defaults to 4.0): + Scale of the guidance signal. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + num_inference_steps (`int`, *optional*, defaults to 250): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple. + """ + + batch_size = len(class_labels) + latent_size = self.transformer.config.sample_size + latent_channels = self.transformer.config.in_channels + + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + device=self.device, + dtype=self.transformer.dtype, + ) + latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents + + class_labels = torch.tensor(class_labels, device=self.device).reshape(-1) + class_null = torch.tensor([1000] * batch_size, device=self.device) + class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels + + # set step values + self.scheduler.set_timesteps(num_inference_steps) + + for t in self.progress_bar(self.scheduler.timesteps): + if guidance_scale > 1: + half = latent_model_input[: len(latent_model_input) // 2] + latent_model_input = torch.cat([half, half], dim=0) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + timesteps = t + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(timesteps, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(latent_model_input.shape[0]) + # predict noise model_output + noise_pred = self.transformer( + latent_model_input, timestep=timesteps, class_labels=class_labels_input + ).sample + + # perform guidance + if guidance_scale > 1: + eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + + half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + + noise_pred = torch.cat([eps, rest], dim=1) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + model_output, _ = torch.split(noise_pred, latent_channels, dim=1) + else: + model_output = noise_pred + + # compute previous image: x_t -> x_t-1 + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + + if guidance_scale > 1: + latents, _ = latent_model_input.chunk(2, dim=0) + else: + latents = latent_model_input + + latents = 1 / 0.18215 * latents + samples = self.vae.decode(latents).sample + + samples = (samples / 2 + 0.5).clamp(0, 1) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + samples = samples.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + samples = self.numpy_to_pil(samples) + + if not return_dict: + return (samples,) + + return ImagePipelineOutput(images=samples) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c3b4b905e0..b38ca866d5 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -22,14 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -88,14 +81,7 @@ class StableDiffusionPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 7e876f49c6..fca9cb9e37 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -25,14 +25,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -91,14 +84,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, depth_estimator: DPTForDepthEstimation, feature_extractor: DPTFeatureExtractor, ): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py index e9ca167707..37d4a50efc 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py @@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -73,14 +66,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline): vae: AutoencoderKL, image_encoder: CLIPVisionModelWithProjection, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 2ec2674840..fceb45e757 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( PIL_INTERPOLATION, deprecate, @@ -133,14 +126,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1eb7109375..140aa8da2a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -173,7 +173,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py index 1f0be3ac0b..588682b4ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py @@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionPipelineOutput @@ -100,14 +93,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index d5eb63ca5d..af4caa3202 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -22,7 +22,7 @@ import PIL from transformers import CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -84,7 +84,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline): tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, low_res_scheduler: DDPMScheduler, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, max_noise_level: int = 350, ): super().__init__() diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index b2bed9d208..ff4b41a9dc 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -10,14 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...configuration_utils import FrozenDict from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, -) +from ...schedulers import KarrasDiffusionSchedulers from ...utils import deprecate, is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline from . import StableDiffusionSafePipelineOutput @@ -65,14 +58,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[ - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - ], + scheduler: KarrasDiffusionSchedulers, safety_checker: SafeStableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, requires_safety_checker: bool = True, diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py index 88e7e4b6a4..ec8be907bb 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion.py @@ -7,7 +7,7 @@ import PIL.Image from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import logging from ..pipeline_utils import DiffusionPipeline from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline @@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNet2DConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers def __init__( self, @@ -64,7 +64,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py index 71bfe56b03..4602448542 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py @@ -28,7 +28,7 @@ from transformers import ( ) from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel @@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] @@ -75,7 +75,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline): image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py index 56b532010c..b08d9bb143 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py @@ -23,7 +23,7 @@ import PIL from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection from ...models import AutoencoderKL, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput @@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): image_encoder: CLIPVisionModelWithProjection image_unet: UNet2DConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers def __init__( self, @@ -61,7 +61,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline): image_encoder: CLIPVisionModelWithProjection, image_unet: UNet2DConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( diff --git a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py index ac0adf5e7a..06d8773eaf 100644 --- a/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +++ b/src/diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py @@ -21,7 +21,7 @@ import torch.utils.checkpoint from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel -from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_accelerate_available, logging, randn_tensor from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from .modeling_text_unet import UNetFlatConditionModel @@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): image_unet: UNet2DConditionModel text_unet: UNetFlatConditionModel vae: AutoencoderKL - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler] + scheduler: KarrasDiffusionSchedulers _optional_components = ["text_unet"] @@ -65,7 +65,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline): image_unet: UNet2DConditionModel, text_unet: UNetFlatConditionModel, vae: AutoencoderKL, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + scheduler: KarrasDiffusionSchedulers, ): super().__init__() self.register_modules( diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 298bbb9ef4..3746acd5b5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,7 +39,7 @@ else: from .scheduling_sde_ve import ScoreSdeVeScheduler from .scheduling_sde_vp import ScoreSdeVpScheduler from .scheduling_unclip import UnCLIPScheduler - from .scheduling_utils import SchedulerMixin + from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin from .scheduling_vq_diffusion import VQDiffusionScheduler try: @@ -55,7 +55,12 @@ else: from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler from .scheduling_pndm_flax import FlaxPNDMScheduler from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler - from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left + from .scheduling_utils_flax import ( + FlaxKarrasDiffusionSchedulers, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, + ) try: diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 32a16071c6..6a9fe29c62 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,8 +23,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, deprecate, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass @@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] order = 1 diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 9c675f1754..52a997fa98 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -24,8 +24,8 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -102,7 +102,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 12c17fd169..b58ed83382 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,8 +22,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, deprecate, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass @@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] order = 1 diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index ed83ae8df2..8223b340cb 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -24,8 +24,8 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py index 528bef9e09..1ad5480b78 100644 --- a/src/diffusers/schedulers/scheduling_deis_multistep.py +++ b/src/diffusers/schedulers/scheduling_deis_multistep.py @@ -22,8 +22,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 2a003920ec..8acb87d78a 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,8 +21,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import deprecate +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] order = 1 diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index 0aa121b59d..ed2ed5f5e5 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -24,8 +24,8 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from ..utils import deprecate from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] _deprecated_kwargs = ["predict_epsilon"] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py index d016711b59..0225d8027b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py @@ -21,8 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 9976235b75..45f939aafe 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -19,8 +19,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 10f277f7e0..02e5c2cd99 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -19,8 +19,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput, logging, randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_heun_discrete.py b/src/diffusers/schedulers/scheduling_heun_discrete.py index 4f40a24050..0dea944b6f 100644 --- a/src/diffusers/schedulers/scheduling_heun_discrete.py +++ b/src/diffusers/schedulers/scheduling_heun_discrete.py @@ -18,8 +18,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py index 370a078704..175f338b92 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py @@ -18,8 +18,8 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from ..utils import randn_tensor +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py index 8aee346c57..18dd976716 100644 --- a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py +++ b/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py @@ -18,8 +18,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 2 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 28bc9bd0c6..f2c474ffe1 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -21,8 +21,8 @@ import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin @dataclass @@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): https://imagen.research.google/video/paper.pdf) """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index fde18f2653..e105ded997 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -21,8 +21,8 @@ from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left, @@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index c3ac5fdf75..065a07e955 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -21,8 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS -from .scheduling_utils import SchedulerMixin, SchedulerOutput +from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -92,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ - _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in KarrasDiffusionSchedulers] order = 1 @register_to_config diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 25c0db9346..572da53464 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,8 +23,8 @@ import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config from .scheduling_utils_flax import ( - _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CommonSchedulerState, + FlaxKarrasDiffusionSchedulers, FlaxSchedulerMixin, FlaxSchedulerOutput, add_noise_common, @@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): the `dtype` used for params and computation. """ - _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers] dtype: jnp.dtype pndm_order: int diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 90ab674e38..f4103d4d62 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -14,6 +14,7 @@ import importlib import os from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, Optional, Union import torch @@ -24,6 +25,21 @@ from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" +class KarrasDiffusionSchedulers(Enum): + DDIMScheduler = 1 + DDPMScheduler = 2 + PNDMScheduler = 3 + LMSDiscreteScheduler = 4 + EulerDiscreteScheduler = 5 + HeunDiscreteScheduler = 6 + EulerAncestralDiscreteScheduler = 7 + DPMSolverMultistepScheduler = 8 + DPMSolverSinglestepScheduler = 9 + KDPM2DiscreteScheduler = 10 + KDPM2AncestralDiscreteScheduler = 11 + DEISMultistepScheduler = 12 + + @dataclass class SchedulerOutput(BaseOutput): """ diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index 889c0f25bc..9708c08837 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -15,16 +15,24 @@ import importlib import math import os from dataclasses import dataclass +from enum import Enum from typing import Any, Dict, Optional, Tuple, Union import flax import jax.numpy as jnp -from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput +from ..utils import BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" -_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS] + + +class FlaxKarrasDiffusionSchedulers(Enum): + FlaxDDIMScheduler = 1 + FlaxDDPMScheduler = 2 + FlaxPNDMScheduler = 3 + FlaxLMSDiscreteScheduler = 4 + FlaxDPMSolverMultistepScheduler = 5 @dataclass diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 4f2bc27bef..ece9c6d3f9 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -19,7 +19,6 @@ from packaging import version from .. import __version__ from .constants import ( - _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, CONFIG_NAME, DIFFUSERS_CACHE, DIFFUSERS_DYNAMIC_MODULE_NAME, diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 35efff392c..0edb4c57f0 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -30,18 +30,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) - -_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "HeunDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - "DPMSolverSinglestepScheduler", - "KDPM2DiscreteScheduler", - "KDPM2AncestralDiscreteScheduler", - "DEISMultistepScheduler", -] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 62c2bbc273..1e7c0a46a2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -227,6 +227,21 @@ class DiffusionPipeline(metaclass=DummyObject): requires_backends(cls, ["torch"]) +class DiTPipeline(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class ImagePipelineOutput(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/pipelines/dit/__init__.py b/tests/pipelines/dit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/pipelines/dit/test_dit.py b/tests/pipelines/dit/test_dit.py new file mode 100644 index 0000000000..ab41f9751c --- /dev/null +++ b/tests/pipelines/dit/test_dit.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import unittest + +import numpy as np +import torch + +from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel +from diffusers.utils import load_numpy, slow +from diffusers.utils.testing_utils import require_torch_gpu + +from ...test_pipelines_common import PipelineTesterMixin + + +torch.backends.cuda.matmul.allow_tf32 = False + + +class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = DiTPipeline + test_cpu_offload = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = Transformer2DModel( + sample_size=4, + num_layers=2, + patch_size=2, + attention_head_dim=2, + num_attention_heads=2, + in_channels=4, + out_channels=8, + attention_bias=True, + activation_fn="gelu-approximate", + num_embeds_ada_norm=1000, + norm_type="ada_norm_zero", + norm_elementwise_affine=False, + ) + vae = AutoencoderKL() + scheduler = DDIMScheduler() + components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler} + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "class_labels": [1], + "generator": generator, + "num_inference_steps": 2, + "output_type": "numpy", + } + return inputs + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 4, 4, 3)) + expected_slice = np.array( + [0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058] + ) + max_diff = np.abs(image_slice.flatten() - expected_slice).max() + self.assertLessEqual(max_diff, 1e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(relax_max_difference=True) + + +@require_torch_gpu +@slow +class DiTPipelineIntegrationTests(unittest.TestCase): + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_dit_256(self): + generator = torch.manual_seed(0) + + pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256") + pipe.to("cuda") + + words = ["vase", "umbrella", "white shark", "white wolf"] + ids = pipe.get_label_ids(words) + + images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images + + for word, image in zip(words, images): + expected_image = load_numpy( + f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy" + ) + assert np.abs((expected_image - image).sum()) < 1e-3 + + def test_dit_512_fp16(self): + generator = torch.manual_seed(0) + + pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.to("cuda") + + words = ["vase", "umbrella", "white shark", "white wolf"] + ids = pipe.get_label_ids(words) + + images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images + + for word, image in zip(words, images): + expected_image = load_numpy( + "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" + f"/dit/{word}_fp16.npy" + ) + assert np.abs((expected_image - image).sum()) < 1e-3 diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py index 7babec5888..08f13b8960 100644 --- a/tests/test_pipelines_common.py +++ b/tests/test_pipelines_common.py @@ -36,7 +36,7 @@ class PipelineTesterMixin: equivalence of dict and tuple outputs, etc. """ - allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"] + allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image", "class_labels"] required_optional_params = ["generator", "num_inference_steps", "return_dict"] num_inference_steps_args = ["num_inference_steps"] @@ -194,8 +194,8 @@ class PipelineTesterMixin: ): if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]: # RePaint can hardly be made deterministic since the scheduler is currently always - # indeterministic - # CycleDiffusion is also slighly undeterministic + # nondeterministic + # CycleDiffusion is also slightly nondeterministic return if test_max_difference is None: @@ -515,7 +515,7 @@ class PipelineTesterMixin: torch_device != "cuda" or not is_xformers_available(), reason="XFormers attention is only available with CUDA and `xformers` installed", ) - def test_xformers_attention_forward_pass(self): + def test_xformers_attention_forwardGenerator_pass(self): if not self.test_xformers_attention: return