diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e904067b31..72808df049 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -132,8 +132,6 @@ title: Conceptual Guides - sections: - sections: - - local: api/models - title: Models - local: api/attnprocessor title: Attention Processor - local: api/diffusion_pipeline @@ -151,6 +149,30 @@ - local: api/image_processor title: VAE Image Processor title: Main Classes + - sections: + - local: api/models/overview + title: Overview + - local: api/models/unet + title: UNet1DModel + - local: api/models/unet2d + title: UNet2DModel + - local: api/models/unet2d-cond + title: UNet2DConditionModel + - local: api/models/unet3d-cond + title: UNet3DConditionModel + - local: api/models/vq + title: VQModel + - local: api/models/autoencoderkl + title: AutoencoderKL + - local: api/models/transformer2d + title: Transformer2D + - local: api/models/transformer_temporal + title: Transformer Temporal + - local: api/models/prior_transformer + title: Prior Transformer + - local: api/models/controlnet + title: ControlNet + title: Models - sections: - local: api/pipelines/overview title: Overview diff --git a/docs/source/en/api/models.mdx b/docs/source/en/api/models.mdx deleted file mode 100644 index 74291f9173..0000000000 --- a/docs/source/en/api/models.mdx +++ /dev/null @@ -1,107 +0,0 @@ - - -# Models - -Diffusers contains pretrained models for popular algorithms and modules for creating the next set of diffusion models. -The primary function of these models is to denoise an input sample, by modeling the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\). -The models are built on the base class ['ModelMixin'] that is a `torch.nn.module` with basic functionality for saving and loading models both locally and from the HuggingFace hub. - -## ModelMixin -[[autodoc]] ModelMixin - -## UNet2DOutput -[[autodoc]] models.unet_2d.UNet2DOutput - -## UNet2DModel -[[autodoc]] UNet2DModel - -## UNet1DOutput -[[autodoc]] models.unet_1d.UNet1DOutput - -## UNet1DModel -[[autodoc]] UNet1DModel - -## UNet2DConditionOutput -[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput - -## UNet2DConditionModel -[[autodoc]] UNet2DConditionModel - -## UNet3DConditionOutput -[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput - -## UNet3DConditionModel -[[autodoc]] UNet3DConditionModel - -## DecoderOutput -[[autodoc]] models.vae.DecoderOutput - -## VQEncoderOutput -[[autodoc]] models.vq_model.VQEncoderOutput - -## VQModel -[[autodoc]] VQModel - -## AutoencoderKLOutput -[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput - -## AutoencoderKL -[[autodoc]] AutoencoderKL - -## Transformer2DModel -[[autodoc]] Transformer2DModel - -## Transformer2DModelOutput -[[autodoc]] models.transformer_2d.Transformer2DModelOutput - -## TransformerTemporalModel -[[autodoc]] models.transformer_temporal.TransformerTemporalModel - -## Transformer2DModelOutput -[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput - -## PriorTransformer -[[autodoc]] models.prior_transformer.PriorTransformer - -## PriorTransformerOutput -[[autodoc]] models.prior_transformer.PriorTransformerOutput - -## ControlNetOutput -[[autodoc]] models.controlnet.ControlNetOutput - -## ControlNetModel -[[autodoc]] ControlNetModel - -## FlaxModelMixin -[[autodoc]] FlaxModelMixin - -## FlaxUNet2DConditionOutput -[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput - -## FlaxUNet2DConditionModel -[[autodoc]] FlaxUNet2DConditionModel - -## FlaxDecoderOutput -[[autodoc]] models.vae_flax.FlaxDecoderOutput - -## FlaxAutoencoderKLOutput -[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput - -## FlaxAutoencoderKL -[[autodoc]] FlaxAutoencoderKL - -## FlaxControlNetOutput -[[autodoc]] models.controlnet_flax.FlaxControlNetOutput - -## FlaxControlNetModel -[[autodoc]] FlaxControlNetModel diff --git a/docs/source/en/api/models/autoencoderkl.mdx b/docs/source/en/api/models/autoencoderkl.mdx new file mode 100644 index 0000000000..542fc27cd5 --- /dev/null +++ b/docs/source/en/api/models/autoencoderkl.mdx @@ -0,0 +1,31 @@ +# AutoencoderKL + +The variational autoencoder (VAE) model with KL loss was introduced in [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114v11) by Diederik P. Kingma and Max Welling. The model is used in 🤗 Diffusers to encode images into latents and to decode latent representations into images. + +The abstract from the paper is: + +*How can we perform efficient inference and learning in directed probabilistic models, in the presence of continuous latent variables with intractable posterior distributions, and large datasets? We introduce a stochastic variational inference and learning algorithm that scales to large datasets and, under some mild differentiability conditions, even works in the intractable case. Our contributions are two-fold. First, we show that a reparameterization of the variational lower bound yields a lower bound estimator that can be straightforwardly optimized using standard stochastic gradient methods. Second, we show that for i.i.d. datasets with continuous latent variables per datapoint, posterior inference can be made especially efficient by fitting an approximate inference model (also called a recognition model) to the intractable posterior using the proposed lower bound estimator. Theoretical advantages are reflected in experimental results.* + +## AutoencoderKL + +[[autodoc]] AutoencoderKL + +## AutoencoderKLOutput + +[[autodoc]] models.autoencoder_kl.AutoencoderKLOutput + +## DecoderOutput + +[[autodoc]] models.vae.DecoderOutput + +## FlaxAutoencoderKL + +[[autodoc]] FlaxAutoencoderKL + +## FlaxAutoencoderKLOutput + +[[autodoc]] models.vae_flax.FlaxAutoencoderKLOutput + +## FlaxDecoderOutput + +[[autodoc]] models.vae_flax.FlaxDecoderOutput \ No newline at end of file diff --git a/docs/source/en/api/models/controlnet.mdx b/docs/source/en/api/models/controlnet.mdx new file mode 100644 index 0000000000..ae2d06edbb --- /dev/null +++ b/docs/source/en/api/models/controlnet.mdx @@ -0,0 +1,23 @@ +# ControlNet + +The ControlNet model was introduced in [Adding Conditional Control to Text-to-Image Diffusion Models](https://huggingface.co/papers/2302.05543) by Lvmin Zhang and Maneesh Agrawala. It provides a greater degree of control over text-to-image generation by conditioning the model on additional inputs such as edge maps, depth maps, segmentation maps, and keypoints for pose detection. + +The abstract from the paper is: + +*We present a neural network structure, ControlNet, to control pretrained large diffusion models to support additional input conditions. The ControlNet learns task-specific conditions in an end-to-end way, and the learning is robust even when the training dataset is small (< 50k). Moreover, training a ControlNet is as fast as fine-tuning a diffusion model, and the model can be trained on a personal devices. Alternatively, if powerful computation clusters are available, the model can scale to large amounts (millions to billions) of data. We report that large diffusion models like Stable Diffusion can be augmented with ControlNets to enable conditional inputs like edge maps, segmentation maps, keypoints, etc. This may enrich the methods to control large diffusion models and further facilitate related applications.* + +## ControlNetModel + +[[autodoc]] ControlNetModel + +## ControlNetOutput + +[[autodoc]] models.controlnet.ControlNetOutput + +## FlaxControlNetModel + +[[autodoc]] FlaxControlNetModel + +## FlaxControlNetOutput + +[[autodoc]] models.controlnet_flax.FlaxControlNetOutput \ No newline at end of file diff --git a/docs/source/en/api/models/overview.mdx b/docs/source/en/api/models/overview.mdx new file mode 100644 index 0000000000..cc94861fba --- /dev/null +++ b/docs/source/en/api/models/overview.mdx @@ -0,0 +1,12 @@ +# Models + +🤗 Diffusers provides pretrained models for popular algorithms and modules to create custom diffusion systems. The primary function of models is to denoise an input sample as modeled by the distribution \\(p_{\theta}(x_{t-1}|x_{t})\\). + +All models are built from the base [`ModelMixin`] class which is a [`torch.nn.module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html) providing basic functionality for saving and loading models, locally and from the Hugging Face Hub. + +## ModelMixin +[[autodoc]] ModelMixin + +## FlaxModelMixin + +[[autodoc]] FlaxModelMixin \ No newline at end of file diff --git a/docs/source/en/api/models/prior_transformer.mdx b/docs/source/en/api/models/prior_transformer.mdx new file mode 100644 index 0000000000..1d2b799ed3 --- /dev/null +++ b/docs/source/en/api/models/prior_transformer.mdx @@ -0,0 +1,16 @@ +# Prior Transformer + +The Prior Transformer was originally introduced in [Hierarchical Text-Conditional Image Generation with CLIP Latents +](https://huggingface.co/papers/2204.06125) by Ramesh et al. It is used to predict CLIP image embeddings from CLIP text embeddings; image embeddings are predicted through a denoising diffusion process. + +The abstract from the paper is: + +*Contrastive models like CLIP have been shown to learn robust representations of images that capture both semantics and style. To leverage these representations for image generation, we propose a two-stage model: a prior that generates a CLIP image embedding given a text caption, and a decoder that generates an image conditioned on the image embedding. We show that explicitly generating image representations improves image diversity with minimal loss in photorealism and caption similarity. Our decoders conditioned on image representations can also produce variations of an image that preserve both its semantics and style, while varying the non-essential details absent from the image representation. Moreover, the joint embedding space of CLIP enables language-guided image manipulations in a zero-shot fashion. We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher-quality samples.* + +## PriorTransformer + +[[autodoc]] PriorTransformer + +## PriorTransformerOutput + +[[autodoc]] models.prior_transformer.PriorTransformerOutput \ No newline at end of file diff --git a/docs/source/en/api/models/transformer2d.mdx b/docs/source/en/api/models/transformer2d.mdx new file mode 100644 index 0000000000..4ad2b00b6f --- /dev/null +++ b/docs/source/en/api/models/transformer2d.mdx @@ -0,0 +1,29 @@ +# Transformer2D + +A Transformer model for image-like data from [CompVis](https://huggingface.co/CompVis) that is based on the [Vision Transformer](https://huggingface.co/papers/2010.11929) introduced by Dosovitskiy et al. The [`Transformer2DModel`] accepts discrete (classes of vector embeddings) or continuous (actual embeddings) inputs. + +When the input is **continuous**: + +1. Project the input and reshape it to `(batch_size, sequence_length, feature_dimension)`. +2. Apply the Transformer blocks in the standard way. +3. Reshape to image. + +When the input is **discrete**: + + + +It is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised image don't contain a prediction for the masked pixel because the unnoised image cannot be masked. + + + +1. Convert input (classes of latent pixels) to embeddings and apply positional embeddings. +2. Apply the Transformer blocks in the standard way. +3. Predict classes of unnoised image. + +## Transformer2DModel + +[[autodoc]] Transformer2DModel + +## Transformer2DModelOutput + +[[autodoc]] models.transformer_2d.Transformer2DModelOutput diff --git a/docs/source/en/api/models/transformer_temporal.mdx b/docs/source/en/api/models/transformer_temporal.mdx new file mode 100644 index 0000000000..d67cf717f9 --- /dev/null +++ b/docs/source/en/api/models/transformer_temporal.mdx @@ -0,0 +1,11 @@ +# Transformer Temporal + +A Transformer model for video-like data. + +## TransformerTemporalModel + +[[autodoc]] models.transformer_temporal.TransformerTemporalModel + +## TransformerTemporalModelOutput + +[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput \ No newline at end of file diff --git a/docs/source/en/api/models/unet.mdx b/docs/source/en/api/models/unet.mdx new file mode 100644 index 0000000000..9a488a3231 --- /dev/null +++ b/docs/source/en/api/models/unet.mdx @@ -0,0 +1,13 @@ +# UNet1DModel + +The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 1D UNet model. + +The abstract from the paper is: + +*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.* + +## UNet1DModel +[[autodoc]] UNet1DModel + +## UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput \ No newline at end of file diff --git a/docs/source/en/api/models/unet2d-cond.mdx b/docs/source/en/api/models/unet2d-cond.mdx new file mode 100644 index 0000000000..a669b02a7f --- /dev/null +++ b/docs/source/en/api/models/unet2d-cond.mdx @@ -0,0 +1,19 @@ +# UNet2DConditionModel + +The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet conditional model. + +The abstract from the paper is: + +*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.* + +## UNet2DConditionModel +[[autodoc]] UNet2DConditionModel + +## UNet2DConditionOutput +[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput + +## FlaxUNet2DConditionModel +[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionModel + +## FlaxUNet2DConditionOutput +[[autodoc]] models.unet_2d_condition_flax.FlaxUNet2DConditionOutput \ No newline at end of file diff --git a/docs/source/en/api/models/unet2d.mdx b/docs/source/en/api/models/unet2d.mdx new file mode 100644 index 0000000000..29e8163f64 --- /dev/null +++ b/docs/source/en/api/models/unet2d.mdx @@ -0,0 +1,13 @@ +# UNet2DModel + +The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 2D UNet model. + +The abstract from the paper is: + +*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.* + +## UNet2DModel +[[autodoc]] UNet2DModel + +## UNet2DOutput +[[autodoc]] models.unet_2d.UNet2DOutput \ No newline at end of file diff --git a/docs/source/en/api/models/unet3d-cond.mdx b/docs/source/en/api/models/unet3d-cond.mdx new file mode 100644 index 0000000000..83dbb514c8 --- /dev/null +++ b/docs/source/en/api/models/unet3d-cond.mdx @@ -0,0 +1,13 @@ +# UNet3DConditionModel + +The [UNet](https://huggingface.co/papers/1505.04597) model was originally introduced by Ronneberger et al for biomedical image segmentation, but it is also commonly used in 🤗 Diffusers because it outputs images that are the same size as the input. It is one of the most important components of a diffusion system because it facilitates the actual diffusion process. There are several variants of the UNet model in 🤗 Diffusers, depending on it's number of dimensions and whether it is a conditional model or not. This is a 3D UNet conditional model. + +The abstract from the paper is: + +*There is large consent that successful training of deep networks requires many thousand annotated training samples. In this paper, we present a network and training strategy that relies on the strong use of data augmentation to use the available annotated samples more efficiently. The architecture consists of a contracting path to capture context and a symmetric expanding path that enables precise localization. We show that such a network can be trained end-to-end from very few images and outperforms the prior best method (a sliding-window convolutional network) on the ISBI challenge for segmentation of neuronal structures in electron microscopic stacks. Using the same network trained on transmitted light microscopy images (phase contrast and DIC) we won the ISBI cell tracking challenge 2015 in these categories by a large margin. Moreover, the network is fast. Segmentation of a 512x512 image takes less than a second on a recent GPU. The full implementation (based on Caffe) and the trained networks are available at http://lmb.informatik.uni-freiburg.de/people/ronneber/u-net.* + +## UNet3DConditionModel +[[autodoc]] UNet3DConditionModel + +## UNet3DConditionOutput +[[autodoc]] models.unet_3d_condition.UNet3DConditionOutput \ No newline at end of file diff --git a/docs/source/en/api/models/vq.mdx b/docs/source/en/api/models/vq.mdx new file mode 100644 index 0000000000..cdb6761468 --- /dev/null +++ b/docs/source/en/api/models/vq.mdx @@ -0,0 +1,15 @@ +# VQModel + +The VQ-VAE model was introduced in [Neural Discrete Representation Learning](https://huggingface.co/papers/1711.00937) by Aaron van den Oord, Oriol Vinyals and Koray Kavukcuoglu. The model is used in 🤗 Diffusers to decode latent representations into images. Unlike [`AutoencoderKL`], the [`VQModel`] works in a quantized latent space. + +The abstract from the paper is: + +*Learning useful representations without supervision remains a key challenge in machine learning. In this paper, we propose a simple yet powerful generative model that learns such discrete representations. Our model, the Vector Quantised-Variational AutoEncoder (VQ-VAE), differs from VAEs in two key ways: the encoder network outputs discrete, rather than continuous, codes; and the prior is learnt rather than static. In order to learn a discrete latent representation, we incorporate ideas from vector quantisation (VQ). Using the VQ method allows the model to circumvent issues of "posterior collapse" -- where the latents are ignored when they are paired with a powerful autoregressive decoder -- typically observed in the VAE framework. Pairing these representations with an autoregressive prior, the model can generate high quality images, videos, and speech as well as doing high quality speaker conversion and unsupervised learning of phonemes, providing further evidence of the utility of the learnt representations.* + +## VQModel + +[[autodoc]] VQModel + +## VQEncoderOutput + +[[autodoc]] models.vq_model.VQEncoderOutput \ No newline at end of file diff --git a/src/diffusers/models/autoencoder_kl.py b/src/diffusers/models/autoencoder_kl.py index 7178543132..d61281a53e 100644 --- a/src/diffusers/models/autoencoder_kl.py +++ b/src/diffusers/models/autoencoder_kl.py @@ -39,24 +39,24 @@ class AutoencoderKLOutput(BaseOutput): class AutoencoderKL(ModelMixin, ConfigMixin): - r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma - and Max Welling. + r""" + A VAE model with KL loss for encoding images into latents and decoding latent representations into images. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO + sample_size (`int`, *optional*, defaults to `32`): Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion @@ -131,15 +131,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin): def enable_tiling(self, use_tiling: bool = True): r""" Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to - compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow - the processing of larger images. + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. """ self.use_tiling = use_tiling def disable_tiling(self): r""" - Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to - computing decoding in one step. + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. """ self.enable_tiling(False) @@ -152,7 +152,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): def disable_slicing(self): r""" - Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing decoding in one step. """ self.use_slicing = False @@ -185,11 +185,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin): # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -274,14 +278,21 @@ class AutoencoderKL(ModelMixin, ConfigMixin): def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: r"""Encode a batch of images using a tiled encoder. - Args: When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: - different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is + different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the - look of the output, but they should be much less noticeable. - x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. + output, but they should be much less noticeable. + + Args: + x (`torch.FloatTensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`: + If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain + `tuple` is returned. """ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) @@ -319,17 +330,18 @@ class AutoencoderKL(ModelMixin, ConfigMixin): return AutoencoderKLOutput(latent_dist=posterior) def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: - r"""Decode a batch of images using a tiled decoder. + r""" + Decode a batch of images using a tiled decoder. Args: - When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several - steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: - different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the - tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the - look of the output, but they should be much less noticeable. - z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to - `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + z (`torch.FloatTensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. """ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index 8660c3f9a5..b0f5660200 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -37,6 +37,20 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class ControlNetOutput(BaseOutput): + """ + The output of [`ControlNetModel`]. + + Args: + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + down_block_res_samples: Tuple[torch.Tensor] mid_block_res_sample: torch.Tensor @@ -87,6 +101,58 @@ class ControlNetConditioningEmbedding(nn.Module): class ControlNetModel(ModelMixin, ConfigMixin): + """ + A ControlNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + controlnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + """ + _supports_gradient_checkpointing = True @register_to_config @@ -283,12 +349,12 @@ class ControlNetModel(ModelMixin, ConfigMixin): load_weights_from_unet: bool = True, ): r""" - Instantiate Controlnet class from UNet2DConditionModel. + Instantiate a [`ControlNetModel`] from [`UNet2DConditionModel`]. Parameters: unet (`UNet2DConditionModel`): - UNet model which weights are copied to the ControlNet. Note that all configuration options are also - copied where applicable. + The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied + where applicable. """ controlnet = cls( in_channels=unet.config.in_channels, @@ -357,11 +423,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -397,15 +467,15 @@ class ControlNetModel(ModelMixin, ConfigMixin): r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `num_attention_heads // slice_size`. In this case, - `num_attention_heads` must be a multiple of `slice_size`. + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ sliceable_head_dims = [] @@ -476,6 +546,37 @@ class ControlNetModel(ModelMixin, ConfigMixin): guess_mode: bool = False, return_dict: bool = True, ) -> Union[ControlNetOutput, Tuple]: + """ + The [`ControlNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + controlnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for ControlNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple. + + Returns: + [`~models.controlnet.ControlNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ # check channel order channel_order = self.config.controlnet_conditioning_channel_order diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index cff451edcd..a826df48e4 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -32,6 +32,14 @@ from .unet_2d_blocks_flax import ( @flax.struct.dataclass class FlaxControlNetOutput(BaseOutput): + """ + The output of [`FlaxControlNetModel`]. + + Args: + down_block_res_samples (`jnp.ndarray`): + mid_block_res_sample (`jnp.ndarray`): + """ + down_block_res_samples: jnp.ndarray mid_block_res_sample: jnp.ndarray @@ -95,21 +103,17 @@ class FlaxControlNetConditioningEmbedding(nn.Module): @flax_register_to_config class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN - [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized - training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the - convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides - (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full - model) to encode image-space conditions ... into feature maps ..." + A ControlNet model. - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it’s generic methods + implemented for all models (such as downloading or saving). - Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + This model is also a Flax Linen [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its general usage and behavior. - Finally, this model supports inherent JAX features such as: + Inherent JAX features such as the following are supported: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) @@ -120,9 +124,8 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): The size of the input sample. in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", - "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): @@ -139,11 +142,9 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. controlnet_conditioning_channel_order (`str`, *optional*, defaults to `rgb`): - The channel order of conditional image. Will convert it to `rgb` if it's `bgr` + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. conditioning_embedding_out_channels (`tuple`, *optional*, defaults to `(16, 32, 96, 256)`): - The tuple of output channel for each block in conditioning_embedding layer - - + The tuple of output channel for each block in the `conditioning_embedding` layer. """ sample_size: int = 32 in_channels: int = 4 diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 58c492a974..9a6e1b3bba 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -44,10 +44,12 @@ logger = logging.get_logger(__name__) class FlaxModelMixin: r""" - Base class for all flax models. + Base class for all Flax models. - [`FlaxModelMixin`] takes care of storing the configuration of the models and handles methods for loading, - downloading and saving models. + [`FlaxModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. + + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~FlaxModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @@ -89,15 +91,15 @@ class FlaxModelMixin: Cast the floating-point `params` to `jax.numpy.bfloat16`. This returns a new `params` tree and does not cast the `params` in place. - This method can be used on TPU to explicitly convert the model parameters to bfloat16 precision to do full + This method can be used on a TPU to explicitly convert the model parameters to bfloat16 precision to do full half-precision training or to save weights in bfloat16 for inference in order to save memory and improve speed. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip. + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. Examples: @@ -132,8 +134,8 @@ class FlaxModelMixin: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. Examples: @@ -155,15 +157,15 @@ class FlaxModelMixin: Cast the floating-point `params` to `jax.numpy.float16`. This returns a new `params` tree and does not cast the `params` in place. - This method can be used on GPU to explicitly convert the model parameters to float16 precision to do full + This method can be used on a GPU to explicitly convert the model parameters to float16 precision to do full half-precision training or to save weights in float16 for inference in order to save memory and improve speed. Arguments: params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. mask (`Union[Dict, FrozenDict]`): - A `PyTree` with same structure as the `params` tree. The leaves should be booleans, `True` for params - you want to cast, and should be `False` for those you want to skip + A `PyTree` with same structure as the `params` tree. The leaves should be booleans. It should be `True` + for params you want to cast, and `False` for those you want to skip. Examples: @@ -201,71 +203,68 @@ class FlaxModelMixin: **kwargs, ): r""" - Instantiate a pretrained flax model from a pre-trained model configuration. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. + Instantiate a pretrained Flax model from a pretrained model configuration. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`): Can be either: - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids are namespaced under a user or organization name, like - `runwayml/stable-diffusion-v1-5`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_pretrained`], - e.g., `./my_model_directory/`. + - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model + hosted on the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + using [`~FlaxModelMixin.save_pretrained`]. dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and `jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If - specified all the computation will be performed with the given `dtype`. + specified, all the computation will be performed with the given `dtype`. - **Note that this only specifies the dtype of the computation and does not influence the dtype of model - parameters.** + + + This only specifies the dtype of the *computation* and does not influence the dtype of model + parameters. + + If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and + [`~FlaxModelMixin.to_bf16`]. + + - If you wish to change the dtype of the model parameters, see [`~ModelMixin.to_fp16`] and - [`~ModelMixin.to_bf16`]. model_args (sequence of positional arguments, *optional*): - All remaining positional arguments will be passed to the underlying model's `__init__` method. + All remaining positional arguments are passed to the underlying model's `__init__` method. cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. from_pt (`bool`, *optional*, defaults to `False`): Load the model weights from a PyTorch checkpoint save file. kwargs (remaining dictionary of keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + Can be used to update the configuration object (after it is loaded) and initiate the model (for + example, `output_attentions=True`). Behaves differently depending on whether a `config` is provided or automatically loaded: - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~ConfigMixin.from_config`]). Each key of `kwargs` that corresponds to - a configuration attribute will be used to override said attribute with the supplied `kwargs` - value. Remaining keys that do not correspond to any configuration attribute will be passed to the - underlying model's `__init__` function. + - If a configuration is provided with `config`, `kwargs` are directly passed to the underlying + model's `__init__` method (we assume all relevant updates to the configuration have already been + done). + - If a configuration is not provided, `kwargs` are first passed to the configuration class + initialization function [`~ConfigMixin.from_config`]. Each key of the `kwargs` that corresponds + to a configuration attribute is used to override said attribute with the supplied `kwargs` value. + Remaining keys that do not correspond to any configuration attribute are passed to the underlying + model's `__init__` function. Examples: @@ -276,7 +275,16 @@ class FlaxModelMixin: >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5") >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). >>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/") - ```""" + ``` + + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` + """ config = kwargs.pop("config", None) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) @@ -491,18 +499,18 @@ class FlaxModelMixin: is_main_process: bool = True, ): """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~FlaxModelMixin.from_pretrained`]` class method + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~FlaxModelMixin.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. + Directory to save a model and its configuration file to. Will be created if it doesn't exist. params (`Union[Dict, FrozenDict]`): A `PyTree` of model parameters. is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. """ if os.path.isfile(save_directory): logger.error(f"Provided path ({save_directory}) should be a directory, not a file") diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 9e9b5cde46..cc8df3fe6d 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -154,11 +154,10 @@ class ModelMixin(torch.nn.Module): r""" Base class for all models. - [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading - and saving models. + [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and + saving models. - - **config_name** ([`str`]) -- A filename under which the model should be stored when calling - [`~models.ModelMixin.save_pretrained`]. + - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`]. """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] @@ -190,18 +189,13 @@ class ModelMixin(torch.nn.Module): def is_gradient_checkpointing(self) -> bool: """ Whether gradient checkpointing is activated for this model or not. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". """ return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) def enable_gradient_checkpointing(self): """ - Activates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". + Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). """ if not self._supports_gradient_checkpointing: raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") @@ -209,10 +203,8 @@ class ModelMixin(torch.nn.Module): def disable_gradient_checkpointing(self): """ - Deactivates gradient checkpointing for the current model. - - Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint - activations". + Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or + *checkpoint activations* in other frameworks). """ if self._supports_gradient_checkpointing: self.apply(partial(self._set_gradient_checkpointing, value=False)) @@ -236,13 +228,17 @@ class ModelMixin(torch.nn.Module): def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None): r""" - Enable memory efficient attention as implemented in xformers. + Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). - When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference - time. Speed up at training time is not guaranteed. + When this option is enabled, you should observe lower GPU memory usage and a potential speed up during + inference. Speed up during training is not guaranteed. - Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention - is used. + + + ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes + precedent. + + Parameters: attention_op (`Callable`, *optional*): @@ -268,7 +264,7 @@ class ModelMixin(torch.nn.Module): def disable_xformers_memory_efficient_attention(self): r""" - Disable memory efficient attention as implemented in xformers. + Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/). """ self.set_use_memory_efficient_attention_xformers(False) @@ -281,24 +277,24 @@ class ModelMixin(torch.nn.Module): variant: Optional[str] = None, ): """ - Save a model and its configuration file to a directory, so that it can be re-loaded using the - `[`~models.ModelMixin.from_pretrained`]` class method. + Save a model and its configuration file to a directory so that it can be reloaded using the + [`~models.ModelMixin.from_pretrained`] class method. Arguments: save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. + Directory to save a model and its configuration file to. Will be created if it doesn't exist. is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. + Whether the process calling this is the main process or not. Useful during distributed training and you + need to call this function on all processes. In this case, set `is_main_process=True` only on the main + process to avoid race conditions. save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. Can be configured with the environment variable + The function to use to save the state dictionary. Useful during distributed training when you need to + replace `torch.save` with another method. Can be configured with the environment variable `DIFFUSERS_SAVE_MODE`. safe_serialization (`bool`, *optional*, defaults to `False`): - Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. + If specified, weights are saved in the format `pytorch_model..bin`. """ if safe_serialization and not is_safetensors_available(): raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") @@ -335,107 +331,108 @@ class ModelMixin(torch.nn.Module): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Instantiate a pretrained pytorch model from a pre-trained model configuration. + Instantiate a pretrained PyTorch model from a pretrained model configuration. - The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train - the model, you should first set it back in training mode with `model.train()`. - - The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come - pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning - task. - - The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those - weights are discarded. + The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To + train the model, set it back in training mode with `model.train()`. Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): Can be either: - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - Valid model ids should have an organization name, like `google/ddpm-celebahq-256`. - - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g., - `./my_model_directory/`. + - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on + the Hub. + - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved + with [`~ModelMixin.save_pretrained`]. cache_dir (`Union[str, os.PathLike]`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. + Path to a directory where a downloaded pretrained model configuration is cached if the standard cache + is not used. torch_dtype (`str` or `torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype - will be automatically derived from the model's weights. + Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the + dtype is automatically derived from the model's weights. force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received files. Will attempt to resume the download if such a - file exists. + Whether or not to resume downloading the model weights and configuration files. If set to `False`, any + incompletely downloaded files are deleted. proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): + output_loading_info (`bool`, *optional*, defaults to `False`): Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (i.e., do not try to download the model). + Whether to only load local model weights and configuration files or not. If set to `True`, the model + won't be downloaded from the Hub. use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). + The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from + `diffusers-cli login` (stored in `~/.huggingface`) is used. revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. + The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier + allowed by Git. from_flax (`bool`, *optional*, defaults to `False`): Load the model weights from a Flax checkpoint save file. subfolder (`str`, *optional*, defaults to `""`): - In case the relevant files are located inside a subfolder of the model repo (either remote in - huggingface.co or downloaded locally), you can specify the folder name here. - + The subfolder location of a model file within a larger model repository on the Hub or locally. mirror (`str`, *optional*): - Mirror source to accelerate downloads in China. If you are from China and have an accessibility - problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. - Please refer to the mirror site for more information. + Mirror source to resolve accessibility issues if you're downloading a model in China. We do not + guarantee the timeliness or safety of the source, and you should refer to the mirror site for more + information. device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + A map that specifies where each submodule should go. It doesn't need to be defined for each + parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the same device. - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For more information about each option see [designing a device map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each - GPU and the available CPU RAM if unset. + A dictionary device identifier for the maximum memory. Will default to the maximum memory available for + each GPU and the available CPU RAM if unset. offload_folder (`str` or `os.PathLike`, *optional*): - If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + The path to offload weights if `device_map` contains the value `"disk"`. offload_state_dict (`bool`, *optional*): - If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU - RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to - `True` when there is some disk offload. + If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if + the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` + when there is some disk offload. low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): - Speed up model loading by not initializing the weights and only loading the pre-trained weights. This - also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the - model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch, - setting this argument to `True` will raise an error. + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. variant (`str`, *optional*): - If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is - ignored when using `from_flax`. + Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when + loading `from_flax`. use_safetensors (`bool`, *optional*, defaults to `None`): - If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the - `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from - `safetensors` weights. If set to `False`, loading will *not* use `safetensors`. + If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the + `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` + weights. If set to `False`, `safetensors` weights are not loaded. - It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated - models](https://huggingface.co/docs/hub/models-gated#gated-models). + To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with + `huggingface-cli login`. You can also activate the special + ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a + firewalled environment. - + Example: - Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use - this method in a firewalled environment. + ```py + from diffusers import UNet2DConditionModel - + unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") + ``` + If you get the error message below, you need to finetune the weights for your downstream task: + + ```bash + Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: + - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated + You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. + ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) @@ -852,17 +849,27 @@ class ModelMixin(torch.nn.Module): def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: """ - Get number of (optionally, trainable or non-embeddings) parameters in the module. + Get number of (trainable or non-embedding) parameters in the module. Args: only_trainable (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of trainable parameters - + Whether or not to return only the number of trainable parameters. exclude_embeddings (`bool`, *optional*, defaults to `False`): - Whether or not to return only the number of non-embeddings parameters + Whether or not to return only the number of non-embedding parameters. Returns: `int`: The number of parameters. + + Example: + + ```py + from diffusers import UNet2DConditionModel + + model_id = "runwayml/stable-diffusion-v1-5" + unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet") + unet.num_parameters(only_trainable=True) + 859520964 + ``` """ if exclude_embeddings: diff --git a/src/diffusers/models/prior_transformer.py b/src/diffusers/models/prior_transformer.py index 58804f2672..47785a93e9 100644 --- a/src/diffusers/models/prior_transformer.py +++ b/src/diffusers/models/prior_transformer.py @@ -16,6 +16,8 @@ from .modeling_utils import ModelMixin @dataclass class PriorTransformerOutput(BaseOutput): """ + The output of [`PriorTransformer`]. + Args: predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): The predicted CLIP image embedding conditioned on the CLIP text embedding input. @@ -26,27 +28,20 @@ class PriorTransformerOutput(BaseOutput): class PriorTransformer(ModelMixin, ConfigMixin): """ - The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the - transformer predicts the image embeddings through a denoising diffusion process. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) - - For more details, see the original paper: https://arxiv.org/abs/2204.06125 + A Prior Transformer model. Parameters: num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use. - embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP - image embeddings and text embeddings are both the same dimension. - num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the - length of the prompt after it has been tokenized. + embedding_dim (`int`, *optional*, defaults to 768): + The dimension of the CLIP embeddings. Image embeddings and text embeddings are both the same dimension. + num_embeddings (`int`, *optional*, defaults to 77): The max number of CLIP embeddings allowed (the + length of the prompt after it has been tokenized). additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the - projected hidden_states. The actual length of the used hidden_states is `num_embeddings + + projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings + additional_embeddings`. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - """ @register_to_config @@ -133,11 +128,15 @@ class PriorTransformer(ModelMixin, ConfigMixin): # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -178,10 +177,12 @@ class PriorTransformer(ModelMixin, ConfigMixin): return_dict: bool = True, ): """ + The [`PriorTransformer`] forward method. + Args: hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): - x_t, the currently predicted image embeddings. - timestep (`torch.long`): + The currently predicted image embeddings. + timestep (`torch.LongTensor`): Current denoising step. proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`): Projected embedding vector the denoising process is conditioned on. @@ -190,13 +191,13 @@ class PriorTransformer(ModelMixin, ConfigMixin): attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`): Text mask for the text embeddings. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain + Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain tuple. Returns: [`~models.prior_transformer.PriorTransformerOutput`] or `tuple`: - [`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. """ batch_size = hidden_states.shape[0] diff --git a/src/diffusers/models/transformer_2d.py b/src/diffusers/models/transformer_2d.py index ec4cb37184..83da16838a 100644 --- a/src/diffusers/models/transformer_2d.py +++ b/src/diffusers/models/transformer_2d.py @@ -29,10 +29,12 @@ from .modeling_utils import ModelMixin @dataclass class Transformer2DModelOutput(BaseOutput): """ + The output of [`Transformer2DModel`]. + Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): - Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions - for the unnoised latent pixels. + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. """ sample: torch.FloatTensor @@ -40,40 +42,30 @@ class Transformer2DModelOutput(BaseOutput): class Transformer2DModel(ModelMixin, ConfigMixin): """ - Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual - embeddings) inputs. - - When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard - transformer action. Finally, reshape to image. - - When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional - embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict - classes of unnoised image. - - Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised - image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + A 2D Transformer model for image-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. + The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. num_vector_embeds (`int`, *optional*): - Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). Includes the class for the masked latent pixel. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. - num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. - The number of diffusion steps used during training. Note that this is fixed at training time as it is used - to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for - up to but not more than steps than `num_embeds_ada_norm`. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. + Configure if the `TransformerBlocks` attention should contain a bias parameter. """ @register_to_config @@ -223,31 +215,34 @@ class Transformer2DModel(ModelMixin, ConfigMixin): return_dict: bool = True, ): """ + The [`Transformer2DModel`] forward method. + Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.LongTensor`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. - encoder_attention_mask ( `torch.Tensor`, *optional* ). - Cross-attention mask, applied to encoder_hidden_states. Two formats supported: - Mask `(batch, sequence_length)` True = keep, False = discard. Bias `(batch, 1, sequence_length)` 0 - = keep, -10000 = discard. - If ndim == 2: will be interpreted as a mask, then converted into a bias consistent with the format + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format above. This bias will be added to the cross-attention scores. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. """ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. diff --git a/src/diffusers/models/transformer_temporal.py b/src/diffusers/models/transformer_temporal.py index ece88b8db2..cfafdb055b 100644 --- a/src/diffusers/models/transformer_temporal.py +++ b/src/diffusers/models/transformer_temporal.py @@ -26,9 +26,11 @@ from .modeling_utils import ModelMixin @dataclass class TransformerTemporalModelOutput(BaseOutput): """ + The output of [`TransformerTemporalModel`]. + Args: - sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`) - Hidden states conditioned on `encoder_hidden_states` input. + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. """ sample: torch.FloatTensor @@ -36,24 +38,23 @@ class TransformerTemporalModelOutput(BaseOutput): class TransformerTemporalModel(ModelMixin, ConfigMixin): """ - Transformer model for video-like data. + A Transformer model for video-like data. Parameters: num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. in_channels (`int`, *optional*): - Pass if the input is continuous. The number of channels in the input and output. + The number of channels in the input and output (specify if the input is **continuous**). num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. - sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. - Note that this is fixed at training time as it is used for learning a number of position embeddings. See - `ImagePositionalEmbeddings`. - activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. attention_bias (`bool`, *optional*): - Configure if the TransformerBlocks' attention should contain a bias parameter. + Configure if the `TransformerBlock` attention should contain a bias parameter. double_self_attention (`bool`, *optional*): - Configure if each TransformerBlock should contain two self-attention layers + Configure if each `TransformerBlock` should contain two self-attention layers. """ @register_to_config @@ -114,25 +115,27 @@ class TransformerTemporalModel(ModelMixin, ConfigMixin): return_dict: bool = True, ): """ + The [`TransformerTemporal`] forward method. + Args: - hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. - When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input - hidden_states + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): Conditional embeddings for cross attention layer. If not given, cross-attention defaults to self-attention. timestep ( `torch.long`, *optional*): - Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): - Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels - conditioning. + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. Returns: - [`~models.transformer_2d.TransformerTemporalModelOutput`] or `tuple`: - [`~models.transformer_2d.TransformerTemporalModelOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. """ # 1. Input batch_frames, channel, height, width = hidden_states.shape diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index 34a1d2b516..9b617388f3 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -28,9 +28,11 @@ from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up @dataclass class UNet1DOutput(BaseOutput): """ + The output of [`UNet1DModel`]. + Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, sample_size)`): - Hidden states output. Output of last layer of model. + The hidden states output from the last layer of the model. """ sample: torch.FloatTensor @@ -38,10 +40,10 @@ class UNet1DOutput(BaseOutput): class UNet1DModel(ModelMixin, ConfigMixin): r""" - UNet1DModel is a 1D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + A 1D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. @@ -49,24 +51,24 @@ class UNet1DModel(ModelMixin, ConfigMixin): out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. extra_in_channels (`int`, *optional*, defaults to 0): Number of additional channels to be added to the input of the first down block. Useful for cases where the - input data has more channels than what the model is initially designed for. + input data has more channels than what the model was initially designed for. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`False`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(32, 32, 64)`): Tuple of block output channels. - mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. - out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. - act_fn (`str`, *optional*, defaults to None): optional activation function in UNet blocks. - norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. - layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. - downsample_each_block (`int`, *optional*, defaults to False: - experimental feature for using a UNet without upsampling. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock1D", "DownBlock1DNoSkip", "AttnDownBlock1D")`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(32, 32, 64)`): + Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock1D"`): Block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): Optional output processing block of UNet. + act_fn (`str`, *optional*, defaults to `None`): Optional activation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): The number of groups for normalization. + layers_per_block (`int`, *optional*, defaults to 1): The number of layers per block. + downsample_each_block (`int`, *optional*, defaults to `False`): + Experimental feature for using a UNet without upsampling. """ @register_to_config @@ -197,15 +199,19 @@ class UNet1DModel(ModelMixin, ConfigMixin): return_dict: bool = True, ) -> Union[UNet1DOutput, Tuple]: r""" + The [`UNet1DModel`] forward method. + Args: - sample (`torch.FloatTensor`): `(batch_size, num_channels, sample_size)` noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch_size, num_channels, sample_size)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_1d.UNet1DOutput`] instead of a plain tuple. Returns: - [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + [`~models.unet_1d.UNet1DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_1d.UNet1DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. """ # 1. time diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 4a752fa94a..7077aa8891 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -27,9 +27,11 @@ from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @dataclass class UNet2DOutput(BaseOutput): """ + The output of [`UNet2DModel`]. + Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states output. Output of last layer of model. + The hidden states output from the last layer of the model. """ sample: torch.FloatTensor @@ -37,46 +39,45 @@ class UNet2DOutput(BaseOutput): class UNet2DModel(ModelMixin, ConfigMixin): r""" - UNet2DModel is a 2D UNet model that takes in a noisy sample and a timestep and returns sample shaped output. + A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. Dimensions must be a multiple of `2 ** (len(block_out_channels) - 1)`. - in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image. + in_channels (`int`, *optional*, defaults to 3): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 3): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. - flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`True`): Whether to flip sin to cos for fourier time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block - types. + freq_shift (`int`, *optional*, defaults to 0): Frequency shift for Fourier time embedding. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip sin to cos for Fourier time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): + Tuple of downsample block types. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`): - The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(224, 448, 672, 896)`): Tuple of block output channels. + Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`. + up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`): + Tuple of block output channels. layers_per_block (`int`, *optional*, defaults to `2`): The number of layers per block. mid_block_scale_factor (`float`, *optional*, defaults to `1`): The scale factor for the mid block. downsample_padding (`int`, *optional*, defaults to `1`): The padding for the downsample convolution. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. - norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization. - norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. + norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. + norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for normalization. resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. - num_class_embeds (`int`, *optional*, defaults to None): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim` when performing class + conditioning with `class_embed_type` equal to `None`. """ @register_to_config @@ -224,17 +225,21 @@ class UNet2DModel(ModelMixin, ConfigMixin): return_dict: bool = True, ) -> Union[UNet2DOutput, Tuple]: r""" + The [`UNet2DModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int): (batch) timesteps + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. class_labels (`torch.FloatTensor`, *optional*, defaults to `None`): Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unet_2d.UNet2DOutput`] instead of a plain tuple. Returns: - [`~models.unet_2d.UNet2DOutput`] or `tuple`: [`~models.unet_2d.UNet2DOutput`] if `return_dict` is True, - otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. + [`~models.unet_2d.UNet2DOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_2d.UNet2DOutput`] is returned, otherwise a `tuple` is + returned where the first element is the sample tensor. """ # 0. center input if necessary if self.config.center_input_sample: diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 7bca5c336c..868511ef66 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -50,9 +50,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet2DConditionOutput(BaseOutput): """ + The output of [`UNet2DConditionModel`]. + Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.FloatTensor @@ -60,17 +62,17 @@ class UNet2DConditionOutput(BaseOutput): class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" - UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. @@ -78,9 +80,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the - mid block layer if `None`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see @@ -92,52 +94,52 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing + If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - encoder_hid_dim (`int`, *optional*, defaults to None): + encoder_hid_dim (`int`, *optional*, defaults to `None`): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to None): - If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to None): + addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. - num_class_embeds (`int`, *optional*, defaults to None): + num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): + time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, default to `None`): + time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, default to `None`): - Optional activation function to use on the time embeddings only one time before they as passed to the rest - of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str, *optional*, default to `None`): + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the - `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will - default to `False`. + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. """ _supports_gradient_checkpointing = True @@ -551,11 +553,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -589,15 +595,15 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `num_attention_heads // slice_size`. In this case, - `num_attention_heads` must be a multiple of `slice_size`. + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ sliceable_head_dims = [] @@ -670,29 +676,28 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" + The [`UNet2DConditionModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): - (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = - discard. Mask will be converted into a bias, which adds large negative values to attention scores - corresponding to "discard" tokens. + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - added_cond_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time - embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and - `addition_embed_type` for more information. + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index 73f7e9263f..352b0b1b5e 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -35,9 +35,11 @@ from .unet_2d_blocks_flax import ( @flax.struct.dataclass class FlaxUNet2DConditionOutput(BaseOutput): """ + The output of [`FlaxUNet2DConditionModel`]. + Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: jnp.ndarray @@ -46,17 +48,17 @@ class FlaxUNet2DConditionOutput(BaseOutput): @flax_register_to_config class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): r""" - FlaxUNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. - This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). - Also, this model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matters related to its general usage and behavior. - Finally, this model supports inherent JAX features such as: + Inherent JAX features such as the following are supported: - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) @@ -69,12 +71,10 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. The corresponding class names will be: "FlaxCrossAttnDownBlock2D", - "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D" - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): - The tuple of upsample blocks to use. The corresponding class names will be: "FlaxUpBlock2D", - "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D" + down_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxCrossAttnDownBlock2D", "FlaxDownBlock2D")`): + The tuple of downsample blocks to use. + up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. layers_per_block (`int`, *optional*, defaults to 2): @@ -91,8 +91,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. use_memory_efficient_attention (`bool`, *optional*, defaults to `False`): - enable memory efficient attention https://arxiv.org/abs/2112.05682 - + Enable memory efficient attention as described [here](https://arxiv.org/abs/2112.05682). """ sample_size: int = 32 diff --git a/src/diffusers/models/unet_3d_condition.py b/src/diffusers/models/unet_3d_condition.py index aa6aa542b1..36dcaf21f8 100644 --- a/src/diffusers/models/unet_3d_condition.py +++ b/src/diffusers/models/unet_3d_condition.py @@ -43,9 +43,11 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): """ + The output of [`UNet3DConditionModel`]. + Args: sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`): - Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. """ sample: torch.FloatTensor @@ -53,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput): class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): r""" - UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep - and returns sample shaped output. + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): @@ -66,7 +68,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): The tuple of downsample blocks to use. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): The tuple of upsample blocks to use. block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): The tuple of output channels for each block. @@ -75,7 +77,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing + If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. @@ -291,15 +293,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `num_attention_heads // slice_size`. In this case, - `num_attention_heads` must be a multiple of `slice_size`. + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ sliceable_head_dims = [] @@ -355,11 +357,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -408,21 +414,24 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin) return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" + The [`UNet3DConditionModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, num_frames, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet3DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: - [`~models.unet_2d_condition.UNet3DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index b54e3964f1..edd516dd38 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -30,7 +30,7 @@ class DecoderOutput(BaseOutput): Args: sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. + The decoded output sample from the last layer of the model. """ sample: torch.FloatTensor diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index 9812954db7..b8f5b1d0e3 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -36,9 +36,9 @@ class FlaxDecoderOutput(BaseOutput): Args: sample (`jnp.ndarray` of shape `(batch_size, num_channels, height, width)`): - Decoded output sample of the model. Output of the last layer of the model. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - Parameters `dtype` + The decoded output sample from the last layer of the model. + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. """ sample: jnp.ndarray @@ -720,40 +720,43 @@ class FlaxDiagonalGaussianDistribution(object): @flax_register_to_config class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): r""" - Flax Implementation of Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational - Bayes by Diederik P. Kingma and Max Welling. + Flax implementation of a VAE model with KL loss for decoding latent representations. + + This model inherits from [`FlaxModelMixin`]. Check the superclass documentation for it's generic methods + implemented for all models (such as downloading or saving). This model is a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module) - subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to + subclass. Use it as a regular Flax Linen module and refer to the Flax documentation for all matter related to its general usage and behavior. - Finally, this model supports inherent JAX features such as: + Inherent JAX features such as the following are supported: + - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) Parameters: - in_channels (:obj:`int`, *optional*, defaults to 3): - Input channels - out_channels (:obj:`int`, *optional*, defaults to 3): - Output channels - down_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): - DownEncoder block type - up_block_types (:obj:`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): - UpDecoder block type - block_out_channels (:obj:`Tuple[str]`, *optional*, defaults to `(64,)`): - Tuple containing the number of output channels for each block - layers_per_block (:obj:`int`, *optional*, defaults to `2`): - Number of Resnet layer for each block - act_fn (:obj:`str`, *optional*, defaults to `silu`): - Activation function - latent_channels (:obj:`int`, *optional*, defaults to `4`): - Latent space channels - norm_num_groups (:obj:`int`, *optional*, defaults to `32`): - Norm num group - sample_size (:obj:`int`, *optional*, defaults to 32): - Sample input size + in_channels (`int`, *optional*, defaults to 3): + Number of channels in the input image. + out_channels (`int`, *optional*, defaults to 3): + Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to `(DownEncoderBlock2D)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `(UpDecoderBlock2D)`): + Tuple of upsample block types. + block_out_channels (`Tuple[str]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. + layers_per_block (`int`, *optional*, defaults to `2`): + Number of ResNet layer for each block. + act_fn (`str`, *optional*, defaults to `silu`): + The activation function to use. + latent_channels (`int`, *optional*, defaults to `4`): + Number of channels in the latent space. + norm_num_groups (`int`, *optional*, defaults to `32`): + The number of groups for normalization. + sample_size (`int`, *optional*, defaults to 32): + Sample input size. scaling_factor (`float`, *optional*, defaults to 0.18215): The component-wise standard deviation of the trained latent space computed using the first batch of the training set. This is used to scale the latent space to have unit variance when training the diffusion @@ -761,8 +764,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. - dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32): - parameters `dtype` + dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`): + The `dtype` of the parameters. """ in_channels: int = 3 out_channels: int = 3 diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index 73158294ee..32f944daca 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -30,31 +30,31 @@ class VQEncoderOutput(BaseOutput): Args: latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - Encoded output sample of the model. Output of the last layer of the model. + The encoded output sample from the last layer of the model. """ latents: torch.FloatTensor class VQModel(ModelMixin, ConfigMixin): - r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray - Kavukcuoglu. + r""" + A VQ-VAE model for decoding latent representations. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the model (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: in_channels (int, *optional*, defaults to 3): Number of channels in the input image. out_channels (int, *optional*, defaults to 3): Number of channels in the output. - down_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. - up_block_types (`Tuple[str]`, *optional*, defaults to : - obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. - block_out_channels (`Tuple[int]`, *optional*, defaults to : - obj:`(64,)`): Tuple of block output channels. + down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`): + Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`): + Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`): + Tuple of block output channels. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space. - sample_size (`int`, *optional*, defaults to `32`): TODO + sample_size (`int`, *optional*, defaults to `32`): Sample input size. num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE. vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE. scaling_factor (`float`, *optional*, defaults to `0.18215`): @@ -143,10 +143,17 @@ class VQModel(ModelMixin, ConfigMixin): def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: r""" + The [`VQModel`] forward method. + Args: sample (`torch.FloatTensor`): Input sample. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vq_model.VQEncoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple` + is returned. """ x = sample h = self.encode(x).latents diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index 0dd2351e60..dd5b7f77c1 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -153,17 +153,17 @@ def get_up_block( # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" - UNetFlatConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a - timestep and returns sample shaped output. + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. - This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library - implements for all the models (such as downloading or saving, etc.) + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). Parameters: sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip the sin to cos in the time embedding. @@ -171,9 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`): The tuple of downsample blocks to use. mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`): - The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`, will skip - the mid block layer if `None`. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`): + Block type for middle of UNet, it can be either `UNetMidBlockFlatCrossAttn` or + `UNetMidBlockFlatSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat")`): The tuple of upsample blocks to use. only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): Whether to include self-attention in the basic transformer blocks, see @@ -185,52 +185,52 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, it will skip the normalization and activation layers in post-processing + If `None`, normalization and activation layers is skipped in post-processing. norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): The dimension of the cross attention features. - encoder_hid_dim (`int`, *optional*, defaults to None): + encoder_hid_dim (`int`, *optional*, defaults to `None`): If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to None): - If given, the `encoder_hidden_states` and potentially other embeddings will be down-projected to text + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. num_attention_heads (`int`, *optional*): The number of attention heads. If not defined, defaults to `attention_head_dim` resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to None): + for ResNet blocks (see [`~models.resnet.ResnetBlockFlat`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to None): + addition_embed_type (`str`, *optional*, defaults to `None`): Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or "text". "text" will use the `TextTimeEmbedding` layer. - num_class_embeds (`int`, *optional*, defaults to None): + num_class_embeds (`int`, *optional*, defaults to `None`): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, default to `positional`): + time_embedding_type (`str`, *optional*, defaults to `positional`): The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, default to `None`): + time_embedding_dim (`int`, *optional*, defaults to `None`): An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, default to `None`): - Optional activation function to use on the time embeddings only one time before they as passed to the rest - of the unet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str, *optional*, default to `None`): + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, default to `None`): - The dimension of `cond_proj` layer in timestep embedding. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when - using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time embeddings with the class embeddings. mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): Whether to use cross attention with the mid block when using the `UNetMidBlockFlatSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is None, the - `only_cross_attention` value will be used as the value for `mid_block_only_cross_attention`. Else, it will - default to `False`. + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. """ _supports_gradient_checkpointing = True @@ -656,11 +656,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" + Sets the attention processor to use to compute attention. + Parameters: - `processor (`dict` of `AttentionProcessor` or `AttentionProcessor`): + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor - of **all** `Attention` layers. - In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.: + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) @@ -694,15 +698,15 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is - provided, uses as many slices as `num_attention_heads // slice_size`. In this case, - `num_attention_heads` must be a multiple of `slice_size`. + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. """ sliceable_head_dims = [] @@ -775,29 +779,28 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin): return_dict: bool = True, ) -> Union[UNet2DConditionOutput, Tuple]: r""" + The [`UNetFlatConditionModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. encoder_attention_mask (`torch.Tensor`): - (batch, sequence_length) cross-attention mask, applied to encoder_hidden_states. True = keep, False = - discard. Mask will be converted into a bias, which adds large negative values to attention scores - corresponding to "discard" tokens. + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). - added_cond_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified includes additonal conditions that can be used for additonal time - embeddings or encoder hidden states projections. See the configurations `encoder_hid_dim_type` and - `addition_embed_type` for more information. + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).