mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add Stable Diffusion 3 (#8483)
* up * add sd3 * update * update * add tests * fix copies * fix docs * update * add dreambooth lora * add LoRA * update * update * update * update * import fix * update * Update src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * import fix 2 * update * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/models/autoencoders/autoencoder_kl.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * update * update * update * fix ckpt id * fix more ids * update * missing doc * Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py Co-authored-by: YiYi Xu <yixu310@gmail.com> * Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * update' * fix * update * Update src/diffusers/models/autoencoders/autoencoder_kl.py * Update src/diffusers/models/autoencoders/autoencoder_kl.py * note on gated access. * requirements * licensing --------- Co-authored-by: sayakpaul <spsayakpaul@gmail.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
@@ -107,7 +107,8 @@
|
||||
title: Create a dataset for training
|
||||
- local: training/adapt_a_model
|
||||
title: Adapt a model to a new task
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: training/unconditional_training
|
||||
title: Unconditional image generation
|
||||
- local: training/text2image
|
||||
@@ -125,8 +126,8 @@
|
||||
- local: training/instructpix2pix
|
||||
title: InstructPix2Pix
|
||||
title: Models
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: training/text_inversion
|
||||
title: Textual Inversion
|
||||
- local: training/dreambooth
|
||||
@@ -140,7 +141,6 @@
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
title: Methods
|
||||
isExpanded: false
|
||||
title: Training
|
||||
- sections:
|
||||
- local: optimization/fp16
|
||||
@@ -187,7 +187,8 @@
|
||||
title: Evaluating Diffusion Models
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/configuration
|
||||
title: Configuration
|
||||
- local: api/logging
|
||||
@@ -195,8 +196,8 @@
|
||||
- local: api/outputs
|
||||
title: Outputs
|
||||
title: Main Classes
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
@@ -210,8 +211,8 @@
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
title: Loaders
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- local: api/models/unet
|
||||
@@ -246,13 +247,15 @@
|
||||
title: HunyuanDiT2DModel
|
||||
- local: api/models/transformer_temporal
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/sd3_transformer2d
|
||||
title: SD3Transformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
title: PriorTransformer
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
title: Models
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/amused
|
||||
@@ -350,6 +353,8 @@
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
@@ -382,8 +387,8 @@
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Pipelines
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
@@ -414,6 +419,8 @@
|
||||
title: EulerAncestralDiscreteScheduler
|
||||
- local: api/schedulers/euler
|
||||
title: EulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_euler_discrete
|
||||
title: FlowMatchEulerDiscreteScheduler
|
||||
- local: api/schedulers/heun
|
||||
title: HeunDiscreteScheduler
|
||||
- local: api/schedulers/ipndm
|
||||
@@ -443,8 +450,8 @@
|
||||
- local: api/schedulers/vq_diffusion
|
||||
title: VQDiffusionScheduler
|
||||
title: Schedulers
|
||||
isExpanded: false
|
||||
- sections:
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: api/internal_classes_overview
|
||||
title: Overview
|
||||
- local: api/attnprocessor
|
||||
@@ -460,5 +467,4 @@
|
||||
- local: api/video_processor
|
||||
title: Video Processor
|
||||
title: Internal classes
|
||||
isExpanded: false
|
||||
title: API
|
||||
|
||||
19
docs/source/en/api/models/sd3_transformer2d.md
Normal file
@@ -0,0 +1,19 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# SD3 Transformer Model

The Transformer model introduced in [Stable Diffusion 3](https://hf.co/papers/2403.03206). Its novelty lies in the MMDiT transformer block, which uses separate weights for the image and text modalities and exchanges information between them through joint attention.

## SD3Transformer2DModel

[[autodoc]] SD3Transformer2DModel
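The model can also be loaded on its own, for example to inspect its configuration or to pass a customized instance to [`StableDiffusion3Pipeline`]. Below is a minimal sketch, assuming you have accepted the gate for the SD3 Medium checkpoint and are logged in; the repository id is the one used throughout the SD3 docs.

```python
import torch
from diffusers import SD3Transformer2DModel

# Load only the transformer (denoiser) component of SD3 Medium.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
print(transformer.config.num_layers)
```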
@@ -0,0 +1,230 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Diffusion 3

Stable Diffusion 3 (SD3) was proposed in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206.pdf) by Patrick Esser, Sumith Kulal, Andreas Blattmann, Rahim Entezari, Jonas Muller, Harry Saini, Yam Levi, Dominik Lorenz, Axel Sauer, Frederic Boesel, Dustin Podell, Tim Dockhorn, Zion English, Kyle Lacey, Alex Goodwin, Yannik Marek, and Robin Rombach.

The abstract from the paper is:

*Diffusion models create data from noise by inverting the forward paths of data towards noise and have emerged as a powerful generative modeling technique for high-dimensional, perceptual data such as images and videos. Rectified flow is a recent generative model formulation that connects data and noise in a straight line. Despite its better theoretical properties and conceptual simplicity, it is not yet decisively established as standard practice. In this work, we improve existing noise sampling techniques for training rectified flow models by biasing them towards perceptually relevant scales. Through a large-scale study, we demonstrate the superior performance of this approach compared to established diffusion formulations for high-resolution text-to-image synthesis. Additionally, we present a novel transformer-based architecture for text-to-image generation that uses separate weights for the two modalities and enables a bidirectional flow of information between image and text tokens, improving text comprehension, typography, and human preference ratings. We demonstrate that this architecture follows predictable scaling trends and correlates lower validation loss to improved text-to-image synthesis as measured by various metrics and human evaluations.*

## Usage Example

_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate._

Use the command below to log in:

```bash
huggingface-cli login
```

<Tip>

The SD3 pipeline uses three text encoders to generate an image. Model offloading is necessary in order for it to run on most commodity hardware. Please use the `torch.float16` data type for additional memory savings.

</Tip>

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world.png")
```

## Memory Optimizations for SD3

SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.

### Running Inference with Model Offloading

The most basic memory optimization available in Diffusers allows you to offload the components of the model to the CPU during inference in order to save memory, at the cost of a slight increase in inference latency. Model offloading only moves a model component onto the GPU when it needs to be executed, while keeping the remaining components on the CPU.

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world.png")
```

### Dropping the T5 Text Encoder during Inference

Removing the memory-intensive 4.7B parameter T5-XXL text encoder during inference can significantly decrease the memory requirements for SD3 with only a slight loss in performance.

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    text_encoder_3=None,
    tokenizer_3=None,
    torch_dtype=torch.float16
)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world-no-T5.png")
```

### Using a Quantized Version of the T5 Text Encoder

We can leverage the `bitsandbytes` library to load and quantize the T5-XXL text encoder to 8-bit precision. This allows you to keep using all three text encoders while only slightly impacting performance.

First install the `bitsandbytes` library.

```shell
pip install bitsandbytes
```

Then load the T5-XXL model using the `BitsAndBytesConfig`.

```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
text_encoder = T5EncoderModel.from_pretrained(
    model_id,
    subfolder="text_encoder_3",
    quantization_config=quantization_config,
)
pipe = StableDiffusion3Pipeline.from_pretrained(
    model_id,
    text_encoder_3=text_encoder,
    device_map="balanced",
    torch_dtype=torch.float16
)

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    negative_prompt="",
    num_inference_steps=28,
    height=1024,
    width=1024,
    guidance_scale=7.0,
).images[0]

image.save("sd3_hello_world-8bit-T5.png")
```

You can find the end-to-end script [here](https://gist.github.com/sayakpaul/82acb5976509851f2db1a83456e504f1).

## Performance Optimizations for SD3

### Using Torch Compile to Speed Up Inference

Using compiled components in the SD3 pipeline can speed up inference by as much as 4X. The following code snippet demonstrates how to compile the Transformer and VAE components of the SD3 pipeline.

```python
import torch
from diffusers import StableDiffusion3Pipeline

torch.set_float32_matmul_precision("high")

torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_check_all_directions = True

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16
).to("cuda")
pipe.set_progress_bar_config(disable=True)

pipe.transformer.to(memory_format=torch.channels_last)
pipe.vae.to(memory_format=torch.channels_last)

pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)

# Warm up: the first few calls trigger compilation and are slow.
prompt = "a photo of a cat holding a sign that says hello world"
for _ in range(3):
    _ = pipe(prompt=prompt, generator=torch.manual_seed(1))

# Run inference with the compiled pipeline.
image = pipe(prompt=prompt, generator=torch.manual_seed(1)).images[0]
image.save("sd3_hello_world.png")
```

Check out the full script [here](https://gist.github.com/sayakpaul/508d89d7aad4f454900813da5d42ca97).

## Loading the original checkpoints via `from_single_file`

The `SD3Transformer2DModel` and `StableDiffusion3Pipeline` classes support loading the original checkpoint files released for these models via the `from_single_file` method.

### Loading the original checkpoints for the `SD3Transformer2DModel`

```python
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors")
```

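You can also combine the two loading modes: pull the transformer from the original single-file checkpoint and let the remaining components come from the diffusers-format repository. The snippet below is a minimal sketch of that pattern; the checkpoint URL and repository id are the same ones used above.

```python
import torch
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

transformer = SD3Transformer2DModel.from_single_file(
    "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium.safetensors",
    torch_dtype=torch.float16,
)

# The remaining components (text encoders, VAE, scheduler) still come from the diffusers repository.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    transformer=transformer,
    torch_dtype=torch.float16,
)
pipe.enable_model_cpu_offload()
```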
### Loading the single checkpoint for the `StableDiffusion3Pipeline`

```python
import torch
from diffusers import StableDiffusion3Pipeline
from transformers import T5EncoderModel

text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
```

<Tip>
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
</Tip>

## StableDiffusion3Pipeline

[[autodoc]] StableDiffusion3Pipeline
  - all
  - __call__
18
docs/source/en/api/schedulers/flow_match_euler_discrete.md
Normal file
@@ -0,0 +1,18 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# FlowMatchEulerDiscreteScheduler

`FlowMatchEulerDiscreteScheduler` is based on the flow-matching sampling introduced in [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).

## FlowMatchEulerDiscreteScheduler
[[autodoc]] FlowMatchEulerDiscreteScheduler
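`FlowMatchEulerDiscreteScheduler` is the scheduler used by [`StableDiffusion3Pipeline`]. The snippet below is a minimal sketch showing how to re-create it from a pipeline's config, for example to experiment with the `shift` value in the scheduler config; the checkpoint id is the gated SD3 Medium repository used in the SD3 docs, and `shift=3.0` is shown purely as an illustrative override.

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# Rebuild the scheduler from the pipeline's config, overriding the timestep shift.
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=3.0)
```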
141
examples/dreambooth/README_sd3.md
Normal file
@@ -0,0 +1,141 @@
# DreamBooth training example for Stable Diffusion 3 (SD3)

[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject.

The `train_dreambooth_sd3.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206). We also provide a LoRA implementation in the `train_dreambooth_lora_sd3.py` script.

> [!NOTE]
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:

```bash
huggingface-cli login
```

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the installation up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/dreambooth` folder and run
```bash
pip install -r requirements_sd3.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, setting torch compile mode to True can give dramatic speedups.
Note also that we use the PEFT library as the backend for LoRA training, so make sure to have `peft>=0.6.0` installed in your environment.

### Dog toy example

Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

Being logged in also allows us to push the trained parameters to the Hugging Face Hub.

Now, we can launch training using:

```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sd3"

accelerate launch train_dreambooth_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training. Make sure to install `bitsandbytes` if you want to do so.

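Once training finishes you can run inference with the saved pipeline. The snippet below is a minimal sketch, assuming the run above wrote the fine-tuned pipeline to the local `trained-sd3` folder (the `OUTPUT_DIR` used here); if you pushed to the Hub, use that repository id instead.

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Path assumed from OUTPUT_DIR above (the training script is expected to save a full pipeline there).
pipe = StableDiffusion3Pipeline.from_pretrained("trained-sd3", torch_dtype=torch.float16).to("cuda")

image = pipe(
    "A photo of sks dog in a bucket",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sks_dog.png")
```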
## LoRA + DreamBooth

[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning-like performance with only a fraction of the learnable parameters.

To perform DreamBooth with LoRA, run:

```bash
export MODEL_NAME="stabilityai/stable-diffusion-3-medium-diffusers"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-sd3-lora"

accelerate launch train_dreambooth_lora_sd3.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub
```
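After LoRA training, the adapter can be loaded on top of the base SD3 pipeline for inference (the PEFT backend mentioned above is required). This is a minimal sketch; `trained-sd3-lora` is the `OUTPUT_DIR` from the command above, or the Hub repository the weights were pushed to.

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Load the LoRA weights produced by the training run above.
pipe.load_lora_weights("trained-sd3-lora")

image = pipe(
    "A photo of sks dog in a bucket",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sks_dog_lora.png")
```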
7
examples/dreambooth/requirements_sd3.txt
Normal file
@@ -0,0 +1,7 @@
accelerate>=0.31.0
torchvision
transformers>=4.41.2
ftfy
tensorboard
Jinja2
peft==0.11.1
1665
examples/dreambooth/train_dreambooth_lora_sd3.py
Normal file
File diff suppressed because it is too large
1760
examples/dreambooth/train_dreambooth_sd3.py
Normal file
File diff suppressed because it is too large
@@ -91,6 +91,7 @@ else:
|
||||
"MultiAdapter",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"SD3Transformer2DModel",
|
||||
"StableCascadeUNet",
|
||||
"T2IAdapter",
|
||||
"T5FilmDecoder",
|
||||
@@ -156,6 +157,7 @@ else:
|
||||
"EDMEulerScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -276,6 +278,8 @@ else:
|
||||
"StableCascadeCombinedPipeline",
|
||||
"StableCascadeDecoderPipeline",
|
||||
"StableCascadePriorPipeline",
|
||||
"StableDiffusion3Img2ImgPipeline",
|
||||
"StableDiffusion3Pipeline",
|
||||
"StableDiffusionAdapterPipeline",
|
||||
"StableDiffusionAttendAndExcitePipeline",
|
||||
"StableDiffusionControlNetImg2ImgPipeline",
|
||||
@@ -497,6 +501,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
MultiAdapter,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3Transformer2DModel,
|
||||
T2IAdapter,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
@@ -559,6 +564,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EDMEulerScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -660,6 +666,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableCascadeCombinedPipeline,
|
||||
StableCascadeDecoderPipeline,
|
||||
StableCascadePriorPipeline,
|
||||
StableDiffusion3Img2ImgPipeline,
|
||||
StableDiffusion3Pipeline,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionAttendAndExcitePipeline,
|
||||
StableDiffusionControlNetImg2ImgPipeline,
|
||||
|
||||
@@ -86,6 +86,7 @@ class VaeImageProcessor(ConfigMixin):
|
||||
self,
|
||||
do_resize: bool = True,
|
||||
vae_scale_factor: int = 8,
|
||||
vae_latent_channels: int = 4,
|
||||
resample: str = "lanczos",
|
||||
do_normalize: bool = True,
|
||||
do_binarize: bool = False,
|
||||
|
||||
@@ -59,7 +59,7 @@ if is_torch_available():
|
||||
_import_structure["utils"] = ["AttnProcsLayers"]
|
||||
if is_transformers_available():
|
||||
_import_structure["single_file"] = ["FromSingleFileMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin"]
|
||||
_import_structure["lora"] = ["LoraLoaderMixin", "StableDiffusionXLLoraLoaderMixin", "SD3LoraLoaderMixin"]
|
||||
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
|
||||
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
|
||||
|
||||
@@ -74,7 +74,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
|
||||
if is_transformers_available():
|
||||
from .ip_adapter import IPAdapterMixin
|
||||
from .lora import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
|
||||
from .lora import LoraLoaderMixin, SD3LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
|
||||
from .single_file import FromSingleFileMixin
|
||||
from .textual_inversion import TextualInversionLoaderMixin
|
||||
|
||||
|
||||
@@ -1337,3 +1337,393 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
if getattr(self.text_encoder_2, "peft_config", None) is not None:
|
||||
del self.text_encoder_2.peft_config
|
||||
self.text_encoder_2._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
class SD3LoraLoaderMixin:
|
||||
r"""
|
||||
Load LoRA layers into [`SD3Transformer2DModel`].
|
||||
"""
|
||||
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(
|
||||
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
|
||||
):
|
||||
"""
|
||||
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.
|
||||
|
||||
All kwargs are forwarded to `self.lora_state_dict`.
|
||||
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
|
||||
|
||||
See [`~loaders.LoraLoaderMixin.load_lora_into_transformer`] for more details on how the state dict is loaded
|
||||
into `self.transformer`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
kwargs (`dict`, *optional*):
|
||||
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Return state dict for lora weights and the network alphas.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
We support loading A1111 formatted LoRA checkpoints in a limited capacity.
|
||||
|
||||
This function is experimental and might change in the future.
|
||||
|
||||
</Tip>
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
||||
incompletely downloaded files are deleted.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# UNet and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None, _pipeline=None):
|
||||
"""
|
||||
This will load the LoRA layers specified in `state_dict` into `transformer`.
|
||||
|
||||
Parameters:
|
||||
state_dict (`dict`):
|
||||
A standard state dict containing the lora layer parameters. The keys can either be indexed directly
into the transformer or prefixed with an additional `transformer`, which can be used to distinguish
between text encoder lora layers.
|
||||
transformer (`SD3Transformer2DModel`):
|
||||
The Transformer model to load the LoRA layers into.
|
||||
adapter_name (`str`, *optional*):
|
||||
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
|
||||
`default_{i}` where i is the total number of adapters being loaded.
|
||||
"""
|
||||
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
|
||||
state_dict = {
|
||||
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys
|
||||
}
|
||||
|
||||
if len(state_dict.keys()) > 0:
|
||||
if adapter_name in getattr(transformer, "peft_config", {}):
|
||||
raise ValueError(
|
||||
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
|
||||
)
|
||||
|
||||
rank = {}
|
||||
for key, val in state_dict.items():
|
||||
if "lora_B" in key:
|
||||
rank[key] = val.shape[1]
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=None, peft_state_dict=state_dict)
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(transformer)
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
|
||||
|
||||
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name)
|
||||
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name)
|
||||
|
||||
if incompatible_keys is not None:
|
||||
# check only for unexpected keys
|
||||
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
|
||||
if unexpected_keys:
|
||||
logger.warning(
|
||||
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
|
||||
f" {unexpected_keys}. "
|
||||
)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
):
|
||||
r"""
|
||||
Save the LoRA parameters corresponding to the transformer.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to save LoRA parameters to. Will be created if it doesn't exist.
|
||||
transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
|
||||
State dict of the LoRA layers corresponding to the `transformer`.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful during distributed training and you
|
||||
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
|
||||
process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful during distributed training when you need to
|
||||
replace `torch.save` with another method. Can be configured with the environment variable
|
||||
`DIFFUSERS_SAVE_MODE`.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
"""
|
||||
state_dict = {}
|
||||
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
if not transformer_lora_layers:
|
||||
raise ValueError("You must pass `transformer_lora_layers`.")
|
||||
|
||||
if transformer_lora_layers:
|
||||
state_dict.update(pack_weights(transformer_lora_layers, cls.transformer_name))
|
||||
|
||||
# Save the model
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
|
||||
recurse_remove_peft_layers(transformer)
|
||||
if hasattr(transformer, "peft_config"):
|
||||
del transformer.peft_config
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
@@ -234,7 +234,7 @@ def _download_diffusers_model_config_from_hub(
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt"]
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
|
||||
@@ -24,6 +24,7 @@ from .single_file_utils import (
|
||||
convert_controlnet_checkpoint,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
@@ -64,6 +65,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
"SD3Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from io import BytesIO
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
@@ -65,11 +66,14 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"inpainting": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"clip": "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"clip_sdxl": "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"clip_sd3": "text_encoders.clip_l.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"open_clip": "cond_stage_model.model.token_embedding.weight",
|
||||
"open_clip_sdxl": "conditioner.embedders.1.model.positional_embedding",
|
||||
"open_clip_sdxl_refiner": "conditioner.embedders.0.model.text_projection",
|
||||
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
|
||||
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
|
||||
"stable_cascade_stage_c": "clip_txt_mapper.weight",
|
||||
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -96,6 +100,9 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-cascade-prior",
|
||||
"subfolder": "prior_lite",
|
||||
},
|
||||
"sd3": {
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
|
||||
},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -242,7 +249,11 @@ LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
|
||||
PLAYGROUND_VAE_SCALING_FACTOR = 0.5
|
||||
LDM_UNET_KEY = "model.diffusion_model."
|
||||
LDM_CONTROLNET_KEY = "control_model."
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
|
||||
LDM_CLIP_PREFIX_TO_REMOVE = [
|
||||
"cond_stage_model.transformer.",
|
||||
"conditioner.embedders.0.transformer.",
|
||||
"text_encoders.clip_l.transformer.",
|
||||
]
|
||||
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
|
||||
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
|
||||
|
||||
@@ -366,6 +377,13 @@ def is_clip_sdxl_model(checkpoint):
|
||||
return False
|
||||
|
||||
|
||||
def is_clip_sd3_model(checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["clip_sd3"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def is_open_clip_model(checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["open_clip"] in checkpoint:
|
||||
return True
|
||||
@@ -380,8 +398,12 @@ def is_open_clip_sdxl_model(checkpoint):
|
||||
return False
|
||||
|
||||
|
||||
def is_open_clip_sd3_model(checkpoint):
|
||||
is_open_clip_sdxl_refiner_model(checkpoint)
|
||||
|
||||
|
||||
def is_open_clip_sdxl_refiner_model(checkpoint):
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
|
||||
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -391,9 +413,11 @@ def is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
is_clip_in_checkpoint = any(
|
||||
[
|
||||
is_clip_model(checkpoint),
|
||||
is_clip_sd3_model(checkpoint),
|
||||
is_open_clip_model(checkpoint),
|
||||
is_open_clip_sdxl_model(checkpoint),
|
||||
is_open_clip_sdxl_refiner_model(checkpoint),
|
||||
is_open_clip_sd3_model(checkpoint),
|
||||
]
|
||||
)
|
||||
if (
|
||||
@@ -456,6 +480,9 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
model_type = "stable_cascade_stage_b"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint:
|
||||
model_type = "sd3"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -1364,6 +1391,10 @@ def create_diffusers_clip_model_from_ldm(
|
||||
prefix = "conditioner.embedders.0.model."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
|
||||
elif is_open_clip_sd3_model(checkpoint):
|
||||
prefix = "text_encoders.clip_g.transformer."
|
||||
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
|
||||
|
||||
else:
|
||||
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
|
||||
|
||||
@@ -1559,3 +1590,212 @@ def _legacy_load_safety_checker(local_files_only, torch_dtype):
|
||||
)
|
||||
|
||||
return {"safety_checker": safety_checker, "feature_extractor": feature_extractor}
|
||||
|
||||
|
||||
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
||||
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_sd3_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {}
|
||||
keys = list(checkpoint.keys())
|
||||
for k in keys:
|
||||
if "model.diffusion_model." in k:
|
||||
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
||||
|
||||
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "joint_blocks" in k))[-1] + 1 # noqa: C401
|
||||
caption_projection_dim = 1536
|
||||
|
||||
# Positional and patch embeddings.
|
||||
converted_state_dict["pos_embed.pos_embed"] = checkpoint.pop("pos_embed")
|
||||
converted_state_dict["pos_embed.proj.weight"] = checkpoint.pop("x_embedder.proj.weight")
|
||||
converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("x_embedder.proj.bias")
|
||||
|
||||
# Timestep embeddings.
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.0.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_1.bias"] = checkpoint.pop("t_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.weight"] = checkpoint.pop(
|
||||
"t_embedder.mlp.2.weight"
|
||||
)
|
||||
converted_state_dict["time_text_embed.timestep_embedder.linear_2.bias"] = checkpoint.pop("t_embedder.mlp.2.bias")
|
||||
|
||||
# Context projections.
|
||||
converted_state_dict["context_embedder.weight"] = checkpoint.pop("context_embedder.weight")
|
||||
converted_state_dict["context_embedder.bias"] = checkpoint.pop("context_embedder.bias")
|
||||
|
||||
# Pooled context projection.
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.weight"] = checkpoint.pop("y_embedder.mlp.0.weight")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_1.bias"] = checkpoint.pop("y_embedder.mlp.0.bias")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.weight"] = checkpoint.pop("y_embedder.mlp.2.weight")
|
||||
converted_state_dict["time_text_embed.text_embedder.linear_2.bias"] = checkpoint.pop("y_embedder.mlp.2.bias")
|
||||
|
||||
# Transformer blocks 🎸.
|
||||
for i in range(num_layers):
|
||||
# Q, K, V
|
||||
sample_q, sample_k, sample_v = torch.chunk(
|
||||
checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.weight"), 3, dim=0
|
||||
)
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"joint_blocks.{i}.x_block.attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
checkpoint.pop(f"joint_blocks.{i}.context_block.attn.qkv.bias"), 3, dim=0
|
||||
)
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_q.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_k.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.weight"] = torch.cat([sample_v])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_v.bias"] = torch.cat([sample_v_bias])
|
||||
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.weight"] = torch.cat([context_v])
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_out.0.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.attn.proj.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.attn.proj.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.attn.to_add_out.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.attn.proj.bias"
|
||||
)
|
||||
|
||||
# norms.
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1.linear.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.adaLN_modulation.1.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"
|
||||
)
|
||||
else:
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.weight"] = swap_scale_shift(
|
||||
checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.weight"),
|
||||
dim=caption_projection_dim,
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.norm1_context.linear.bias"] = swap_scale_shift(
|
||||
checkpoint.pop(f"joint_blocks.{i}.context_block.adaLN_modulation.1.bias"),
|
||||
dim=caption_projection_dim,
|
||||
)
|
||||
|
||||
# ffs.
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff.net.2.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.x_block.mlp.fc2.bias"
|
||||
)
|
||||
if not (i == num_layers - 1):
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc1.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.0.proj.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc1.bias"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.weight"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc2.weight"
|
||||
)
|
||||
converted_state_dict[f"transformer_blocks.{i}.ff_context.net.2.bias"] = checkpoint.pop(
|
||||
f"joint_blocks.{i}.context_block.mlp.fc2.bias"
|
||||
)
|
||||
|
||||
# Final blocks.
|
||||
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
converted_state_dict["norm_out.linear.weight"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.weight"), dim=caption_projection_dim
|
||||
)
|
||||
converted_state_dict["norm_out.linear.bias"] = swap_scale_shift(
|
||||
checkpoint.pop("final_layer.adaLN_modulation.1.bias"), dim=caption_projection_dim
|
||||
)
|
||||
|
||||
return converted_state_dict
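
The per-layer loop above recovers separate Q/K/V projections from each fused `qkv` tensor with `torch.chunk`. A minimal, self-contained sketch of that split (shapes are illustrative only, not the SD3 defaults):

```py
# Illustrative only: the fused qkv weight stacks Q, K and V along dim 0,
# so chunking by 3 on that dim recovers the individual projections.
import torch

inner_dim = 8
fused_qkv_weight = torch.randn(3 * inner_dim, inner_dim)  # stand-in for joint_blocks.{i}.x_block.attn.qkv.weight
q, k, v = torch.chunk(fused_qkv_weight, 3, dim=0)
assert q.shape == k.shape == v.shape == (inner_dim, inner_dim)
```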
|
||||
|
||||
|
||||
def is_t5_in_single_file(checkpoint):
|
||||
if "text_encoders.t5xxl.transformer.shared.weight" in checkpoint:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
text_model_dict = {}
|
||||
|
||||
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
|
||||
|
||||
for key in keys:
|
||||
for prefix in remove_prefixes:
|
||||
if key.startswith(prefix):
|
||||
diffusers_key = key.replace(prefix, "")
|
||||
text_model_dict[diffusers_key] = checkpoint.get(key)
|
||||
|
||||
return text_model_dict
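
The T5 conversion above is essentially a prefix strip. A hedged toy example (key names and tensors are dummies):

```py
# Dummy checkpoint keys to illustrate the prefix stripping above.
import torch

prefix = "text_encoders.t5xxl.transformer.encoder."
ckpt = {
    prefix + "block.0.layer.0.SelfAttention.q.weight": torch.zeros(2, 2),
    "text_encoders.t5xxl.transformer.shared.weight": torch.zeros(2, 2),  # not touched by the loop above
}
converted = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
print(list(converted))  # ['block.0.layer.0.SelfAttention.q.weight']
```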
|
||||
|
||||
|
||||
def create_diffusers_t5_model_from_checkpoint(
|
||||
cls,
|
||||
checkpoint,
|
||||
subfolder="",
|
||||
config=None,
|
||||
torch_dtype=None,
|
||||
local_files_only=None,
|
||||
):
|
||||
if config:
|
||||
config = {"pretrained_model_name_or_path": config}
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
|
||||
model_config = cls.config_class.from_pretrained(**config, subfolder=subfolder, local_files_only=local_files_only)
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls(model_config)
|
||||
|
||||
diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
|
||||
|
||||
if is_accelerate_available():
|
||||
unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
else:
|
||||
model.load_state_dict(diffusers_format_checkpoint)
|
||||
|
||||
@@ -43,6 +43,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
|
||||
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
|
||||
@@ -82,6 +83,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3Transformer2DModel,
|
||||
T5FilmDecoder,
|
||||
Transformer2DModel,
|
||||
TransformerTemporalModel,
|
||||
|
||||
@@ -20,7 +20,7 @@ from torch import nn
|
||||
from ..utils import deprecate, logging
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from .activations import GEGLU, GELU, ApproximateGELU
|
||||
from .attention_processor import Attention
|
||||
from .attention_processor import Attention, JointAttnProcessor2_0
|
||||
from .embeddings import SinusoidalPositionalEmbedding
|
||||
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
|
||||
|
||||
@@ -85,6 +85,130 @@ class GatedSelfAttentionDense(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class JointTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
context_pre_only (`bool`): Whether the `context` (text) stream is only pre-processed in this block, i.e. its
output projection and feed-forward are skipped. Used for the final transformer block in SD3.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
|
||||
super().__init__()
|
||||
|
||||
self.context_pre_only = context_pre_only
|
||||
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
|
||||
|
||||
self.norm1 = AdaLayerNormZero(dim)
|
||||
|
||||
if context_norm_type == "ada_norm_continous":
|
||||
self.norm1_context = AdaLayerNormContinuous(
|
||||
dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
|
||||
)
|
||||
elif context_norm_type == "ada_norm_zero":
|
||||
self.norm1_context = AdaLayerNormZero(dim)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
|
||||
)
|
||||
if hasattr(F, "scaled_dot_product_attention"):
|
||||
processor = JointAttnProcessor2_0()
|
||||
else:
|
||||
raise ValueError(
|
||||
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
|
||||
)
|
||||
self.attn = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim // num_attention_heads,
|
||||
heads=num_attention_heads,
|
||||
out_dim=attention_head_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
bias=True,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
if not context_pre_only:
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
else:
|
||||
self.norm2_context = None
|
||||
self.ff_context = None
|
||||
|
||||
# let chunk size default to None
|
||||
self._chunk_size = None
|
||||
self._chunk_dim = 0
|
||||
|
||||
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
|
||||
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
|
||||
# Sets chunk feed-forward
|
||||
self._chunk_size = chunk_size
|
||||
self._chunk_dim = dim
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
|
||||
):
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
|
||||
|
||||
if self.context_pre_only:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
else:
|
||||
norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, emb=temb
|
||||
)
|
||||
|
||||
# Attention.
|
||||
attn_output, context_attn_output = self.attn(
|
||||
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
|
||||
)
|
||||
|
||||
# Process attention outputs for the `hidden_states`.
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
|
||||
else:
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = hidden_states + ff_output
|
||||
|
||||
# Process attention outputs for the `encoder_hidden_states`.
|
||||
if self.context_pre_only:
|
||||
encoder_hidden_states = None
|
||||
else:
|
||||
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + context_attn_output
|
||||
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
|
||||
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
if self._chunk_size is not None:
|
||||
# "feed_forward_chunk_size" can be used to save memory
|
||||
context_ff_output = _chunked_feed_forward(
|
||||
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
|
||||
)
|
||||
else:
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
|
||||
|
||||
return encoder_hidden_states, hidden_states
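
As a rough usage sketch (toy sizes rather than the SD3 defaults, and assuming a diffusers build that already contains this class), the block consumes image tokens, text tokens and a pooled conditioning embedding, and returns the updated text and image streams:

```py
import torch
from diffusers.models.attention import JointTransformerBlock

dim, heads = 64, 4
block = JointTransformerBlock(dim=dim, num_attention_heads=heads, attention_head_dim=dim, context_pre_only=False)

hidden_states = torch.randn(1, 16, dim)          # image (latent patch) tokens
encoder_hidden_states = torch.randn(1, 8, dim)   # text tokens, already projected to `dim`
temb = torch.randn(1, dim)                       # combined timestep + pooled text embedding

encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(hidden_states.shape, encoder_hidden_states.shape)  # (1, 16, 64) (1, 8, 64)
```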
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
|
||||
@@ -116,6 +116,7 @@ class Attention(nn.Module):
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional["AttnProcessor"] = None,
|
||||
out_dim: int = None,
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
@@ -130,6 +131,7 @@ class Attention(nn.Module):
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
# we make use of this private variable to know whether this class is loaded
|
||||
# with a deprecated state dict so that we can convert it on the fly
|
||||
@@ -207,11 +209,16 @@ class Attention(nn.Module):
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
|
||||
# set attention processor
|
||||
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
|
||||
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
|
||||
@@ -1075,6 +1082,164 @@ class AttnAddedKVProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class JointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
# `context` projections.
|
||||
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : residual.shape[1]],
|
||||
hidden_states[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
return hidden_states, encoder_hidden_states
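
The processor's joint attention boils down to concatenating the image and text token streams, attending over the combined sequence, and splitting the result back. A shape-only sketch of that pattern:

```py
import torch

sample_tokens, context_tokens, dim = 16, 8, 64
hidden = torch.randn(1, sample_tokens, dim)     # image stream
context = torch.randn(1, context_tokens, dim)   # text stream

joint = torch.cat([hidden, context], dim=1)     # attention runs over both streams jointly
# ... scaled_dot_product_attention over `joint` happens here in the processor ...
out_sample = joint[:, :sample_tokens]           # split back: image stream
out_context = joint[:, sample_tokens:]          # split back: text stream
print(out_sample.shape, out_context.shape)      # (1, 16, 64) (1, 8, 64)
```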
|
||||
|
||||
|
||||
class FusedJointAttnProcessor2_0:
|
||||
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
residual = hidden_states
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
context_input_ndim = encoder_hidden_states.ndim
|
||||
if context_input_ndim == 4:
|
||||
batch_size, channel, height, width = encoder_hidden_states.shape
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
qkv = attn.to_qkv(hidden_states)
|
||||
split_size = qkv.shape[-1] // 3
|
||||
query, key, value = torch.split(qkv, split_size, dim=-1)
|
||||
|
||||
# `context` projections.
|
||||
encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
|
||||
split_size = encoder_qkv.shape[-1] // 3
|
||||
(
|
||||
encoder_hidden_states_query_proj,
|
||||
encoder_hidden_states_key_proj,
|
||||
encoder_hidden_states_value_proj,
|
||||
) = torch.split(encoder_qkv, split_size, dim=-1)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
|
||||
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
|
||||
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
hidden_states, encoder_hidden_states = (
|
||||
hidden_states[:, : residual.shape[1]],
|
||||
hidden_states[:, residual.shape[1] :],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
if not attn.context_pre_only:
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
if context_input_ndim == 4:
|
||||
encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class XFormersAttnAddedKVProcessor:
|
||||
r"""
|
||||
Processor for implementing memory efficient attention using xFormers.
|
||||
|
||||
@@ -81,9 +81,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
norm_num_groups: int = 32,
|
||||
sample_size: int = 32,
|
||||
scaling_factor: float = 0.18215,
|
||||
shift_factor: Optional[float] = None,
|
||||
latents_mean: Optional[Tuple[float]] = None,
|
||||
latents_std: Optional[Tuple[float]] = None,
|
||||
force_upcast: float = True,
|
||||
use_quant_conv: bool = True,
|
||||
use_post_quant_conv: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -110,8 +113,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
|
||||
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
|
||||
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
@@ -245,13 +248,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
|
||||
tuple.
|
||||
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
The latent representations of the encoded images. If `return_dict` is True, a
|
||||
[`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
|
||||
returned.
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
||||
"""
|
||||
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
|
||||
return self.tiled_encode(x, return_dict=return_dict)
|
||||
@@ -262,7 +263,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
else:
|
||||
h = self.encoder(x)
|
||||
|
||||
moments = self.quant_conv(h)
|
||||
if self.quant_conv is not None:
|
||||
moments = self.quant_conv(h)
|
||||
else:
|
||||
moments = h
|
||||
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
||||
if not return_dict:
|
||||
@@ -274,7 +279,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
|
||||
return self.tiled_decode(z, return_dict=return_dict)
|
||||
|
||||
z = self.post_quant_conv(z)
|
||||
if self.post_quant_conv is not None:
|
||||
z = self.post_quant_conv(z)
|
||||
|
||||
dec = self.decoder(z)
|
||||
|
||||
if not return_dict:
|
||||
@@ -283,7 +290,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
return DecoderOutput(sample=dec)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
|
||||
def decode(
|
||||
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
|
||||
) -> Union[DecoderOutput, torch.FloatTensor]:
|
||||
"""
|
||||
Decode a batch of images.
|
||||
|
||||
@@ -302,7 +311,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z, return_dict=False)[0]
|
||||
decoded = self._decode(z).sample
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
@@ -333,13 +342,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
Args:
|
||||
x (`torch.Tensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a
|
||||
plain tuple.
|
||||
Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned,
|
||||
otherwise a plain `tuple` is returned.
|
||||
[`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
|
||||
If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
|
||||
`tuple` is returned.
|
||||
"""
|
||||
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
|
||||
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
|
||||
|
||||
@@ -123,7 +123,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
"""2D Image to Patch Embedding with support for SD3 cropping."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -137,12 +137,14 @@ class PatchEmbed(nn.Module):
|
||||
bias=True,
|
||||
interpolation_scale=1,
|
||||
pos_embed_type="sincos",
|
||||
pos_embed_max_size=None, # For SD3 cropping
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_patches = (height // patch_size) * (width // patch_size)
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
@@ -153,26 +155,55 @@ class PatchEmbed(nn.Module):
|
||||
self.norm = None
|
||||
|
||||
self.patch_size = patch_size
|
||||
# See:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
|
||||
self.height, self.width = height // patch_size, width // patch_size
|
||||
self.base_size = height // patch_size
|
||||
self.interpolation_scale = interpolation_scale
|
||||
|
||||
# Calculate positional embeddings based on max size or default
|
||||
if pos_embed_max_size:
|
||||
grid_size = pos_embed_max_size
|
||||
else:
|
||||
grid_size = int(num_patches**0.5)
|
||||
|
||||
if pos_embed_type is None:
|
||||
self.pos_embed = None
|
||||
elif pos_embed_type == "sincos":
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim,
|
||||
int(num_patches**0.5),
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
|
||||
)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
||||
persistent = True if pos_embed_max_size else False
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
|
||||
else:
|
||||
raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
"""Crops positional embeddings for SD3 compatibility."""
|
||||
if self.pos_embed_max_size is None:
|
||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
if height > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
if width > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
def forward(self, latent):
|
||||
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
||||
if self.pos_embed_max_size is not None:
|
||||
height, width = latent.shape[-2:]
|
||||
else:
|
||||
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
||||
|
||||
latent = self.proj(latent)
|
||||
if self.flatten:
|
||||
@@ -181,20 +212,20 @@ class PatchEmbed(nn.Module):
|
||||
latent = self.norm(latent)
|
||||
if self.pos_embed is None:
|
||||
return latent.to(latent.dtype)
|
||||
|
||||
# Interpolate positional embeddings if needed.
|
||||
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
|
||||
if self.height != height or self.width != width:
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim=self.pos_embed.shape[-1],
|
||||
grid_size=(height, width),
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed)
|
||||
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
|
||||
# Interpolate or crop positional embeddings as needed
|
||||
if self.pos_embed_max_size:
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
if self.height != height or self.width != width:
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
embed_dim=self.pos_embed.shape[-1],
|
||||
grid_size=(height, width),
|
||||
base_size=self.base_size,
|
||||
interpolation_scale=self.interpolation_scale,
|
||||
)
|
||||
pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
|
||||
else:
|
||||
pos_embed = self.pos_embed
|
||||
|
||||
return (latent + pos_embed).to(latent.dtype)
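
The cropping logic above is a plain center crop on the positional-embedding grid. A standalone sketch of the arithmetic with made-up sizes:

```py
import torch

pos_embed_max_size, patch_size, embed_dim = 6, 2, 4
pos_embed = torch.randn(1, pos_embed_max_size**2, embed_dim)  # stand-in for the registered buffer

height, width = 8, 4                                # latent height/width in pixels
h, w = height // patch_size, width // patch_size    # 4 x 2 patches
top = (pos_embed_max_size - h) // 2
left = (pos_embed_max_size - w) // 2

grid = pos_embed.reshape(1, pos_embed_max_size, pos_embed_max_size, -1)
cropped = grid[:, top : top + h, left : left + w, :].reshape(1, -1, embed_dim)
print(cropped.shape)  # torch.Size([1, 8, 4])
```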
|
||||
|
||||
@@ -626,6 +657,25 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
return conditioning
|
||||
|
||||
|
||||
class CombinedTimestepTextProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
def forward(self, timestep, pooled_projection):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
return conditioning
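
A quick shape sketch of the new embedding module (toy sizes, hypothetical usage):

```py
import torch
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embed = CombinedTimestepTextProjEmbeddings(embedding_dim=32, pooled_projection_dim=16)
timestep = torch.tensor([10, 500])   # one timestep per sample
pooled = torch.randn(2, 16)          # pooled CLIP text embeddings
conditioning = embed(timestep, pooled)
print(conditioning.shape)            # torch.Size([2, 32])
```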
|
||||
|
||||
|
||||
class HunyuanDiTAttentionPool(nn.Module):
|
||||
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
|
||||
|
||||
@@ -1001,6 +1051,8 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
|
||||
if act_fn == "gelu_tanh":
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
elif act_fn == "silu":
|
||||
self.act_1 = nn.SiLU()
|
||||
elif act_fn == "silu_fp32":
|
||||
self.act_1 = FP32SiLU()
|
||||
else:
|
||||
|
||||
@@ -57,10 +57,12 @@ class AdaLayerNormZero(nn.Module):
|
||||
num_embeddings (`int`): The size of the embeddings dictionary.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, num_embeddings: int):
|
||||
def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None):
|
||||
super().__init__()
|
||||
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
if num_embeddings is not None:
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
else:
|
||||
self.emb = None
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
@@ -69,11 +71,14 @@ class AdaLayerNormZero(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
class_labels: torch.LongTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
class_labels: Optional[torch.LongTensor] = None,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
||||
if self.emb is not None:
|
||||
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
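
With `num_embeddings=None`, `AdaLayerNormZero` can now be driven directly by a precomputed embedding (e.g. the output of `CombinedTimestepTextProjEmbeddings`), which is how `JointTransformerBlock` uses it. A minimal sketch:

```py
import torch
from diffusers.models.normalization import AdaLayerNormZero

norm = AdaLayerNormZero(embedding_dim=32)  # num_embeddings=None -> no internal label embedder
x = torch.randn(2, 16, 32)
emb = torch.randn(2, 32)                   # precomputed conditioning embedding
x_out, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm(x, emb=emb)
print(x_out.shape, gate_msa.shape)         # torch.Size([2, 16, 32]) torch.Size([2, 32])
```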
|
||||
|
||||
@@ -9,4 +9,5 @@ if is_torch_available():
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .t5_film_transformer import T5FilmDecoder
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
src/diffusers/models/transformers/transformer_sd3.py (new file, 344 lines)
@@ -0,0 +1,344 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...models.attention import JointTransformerBlock
|
||||
from ...models.attention_processor import Attention, AttentionProcessor
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.normalization import AdaLayerNormContinuous
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
|
||||
from .transformer_2d import Transformer2DModelOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
"""
|
||||
The Transformer model introduced in Stable Diffusion 3.
|
||||
|
||||
Reference: https://arxiv.org/abs/2403.03206
|
||||
|
||||
Parameters:
|
||||
sample_size (`int`): The width of the latent images. This is fixed during training since
|
||||
it is used to learn a number of position embeddings.
|
||||
patch_size (`int`): Patch size to turn the input data into small patches.
|
||||
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
|
||||
num_layers (`int`, *optional*, defaults to 18): The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
|
||||
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
||||
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
|
||||
pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
|
||||
out_channels (`int`, defaults to 16): Number of output channels.
|
||||
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
):
|
||||
super().__init__()
|
||||
default_out_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else default_out_channels
|
||||
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=self.config.sample_size,
|
||||
width=self.config.sample_size,
|
||||
patch_size=self.config.patch_size,
|
||||
in_channels=self.config.in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
pos_embed_max_size=pos_embed_max_size, # hard-code for now.
|
||||
)
|
||||
self.time_text_embed = CombinedTimestepTextProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
|
||||
)
|
||||
self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.config.caption_projection_dim)
|
||||
|
||||
# `attention_head_dim` is doubled to account for the mixing.
|
||||
# It needs to be crafted when we get the actual checkpoints.
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=self.config.num_attention_heads,
|
||||
attention_head_dim=self.inner_dim,
|
||||
context_pre_only=i == num_layers - 1,
|
||||
)
|
||||
for i in range(self.config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
|
||||
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
|
||||
"""
|
||||
Sets the model to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
|
||||
|
||||
Parameters:
|
||||
chunk_size (`int`, *optional*):
|
||||
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
|
||||
over each tensor of dim=`dim`.
|
||||
dim (`int`, *optional*, defaults to `0`):
|
||||
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
|
||||
or dim=1 (sequence length).
|
||||
"""
|
||||
if dim not in [0, 1]:
|
||||
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
|
||||
|
||||
# By default chunk size is 1
|
||||
chunk_size = chunk_size or 1
|
||||
|
||||
def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
|
||||
if hasattr(module, "set_chunk_feed_forward"):
|
||||
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_feed_forward(child, chunk_size, dim)
|
||||
|
||||
for module in self.children():
|
||||
fn_recursive_feed_forward(module, chunk_size, dim)
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
|
||||
def fuse_qkv_projections(self):
|
||||
"""
|
||||
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
||||
are fused. For cross-attention modules, key and value projection matrices are fused.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
"""
|
||||
self.original_attn_processors = None
|
||||
|
||||
for _, attn_processor in self.attn_processors.items():
|
||||
if "Added" in str(attn_processor.__class__.__name__):
|
||||
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
||||
|
||||
self.original_attn_processors = self.attn_processors
|
||||
|
||||
for module in self.modules():
|
||||
if isinstance(module, Attention):
|
||||
module.fuse_projections(fuse=True)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
|
||||
def unfuse_qkv_projections(self):
|
||||
"""Disables the fused QKV projection if enabled.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This API is 🧪 experimental.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
if self.original_attn_processors is not None:
|
||||
self.set_attn_processor(self.original_attn_processors)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor = None,
|
||||
pooled_projections: torch.FloatTensor = None,
|
||||
timestep: torch.LongTensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`SD3Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
else:
|
||||
lora_scale = 1.0
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
else:
|
||||
logger.warning(
|
||||
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
|
||||
)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
|
||||
hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
|
||||
temb = self.time_text_embed(timestep, pooled_projections)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
patch_size = self.config.patch_size
|
||||
height = height // patch_size
|
||||
width = width // patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
|
||||
)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
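
The unpatchify step at the end of `forward` is a reshape/einsum pair; the same operation in isolation (toy sizes):

```py
import torch

batch, height, width, patch, channels = 1, 2, 3, 2, 16
tokens = torch.randn(batch, height * width, patch * patch * channels)  # output of proj_out

x = tokens.reshape(batch, height, width, patch, patch, channels)
x = torch.einsum("nhwpqc->nchpwq", x)
latent = x.reshape(batch, channels, height * patch, width * patch)
print(latent.shape)  # torch.Size([1, 16, 4, 6])
```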
|
||||
@@ -220,6 +220,7 @@ else:
|
||||
"StableDiffusionLDM3DPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["stable_diffusion_3"] = ["StableDiffusion3Pipeline", "StableDiffusion3Img2ImgPipeline"]
|
||||
_import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"]
|
||||
_import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
|
||||
_import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"]
|
||||
@@ -485,6 +486,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableUnCLIPImg2ImgPipeline,
|
||||
StableUnCLIPPipeline,
|
||||
)
|
||||
from .stable_diffusion_3 import StableDiffusion3Img2ImgPipeline, StableDiffusion3Pipeline
|
||||
from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
|
||||
from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline
|
||||
from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline
|
||||
|
||||
src/diffusers/pipelines/stable_diffusion_3/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["StableDiffusion3PipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_stable_diffusion_3"] = ["StableDiffusion3Pipeline"]
|
||||
_import_structure["pipeline_stable_diffusion_3_img2img"] = ["StableDiffusion3Img2ImgPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
|
||||
from .pipeline_stable_diffusion_3_img2img import StableDiffusion3Img2ImgPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -0,0 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class StableDiffusion3PipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for Stable Diffusion pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`)
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. The PIL images or NumPy array contain the denoised images produced by the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
@@ -0,0 +1,886 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import StableDiffusion3PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import StableDiffusion3Pipeline
|
||||
|
||||
>>> pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> prompt = "A cat holding a sign that says hello world"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("sd3.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
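# Illustrative usage sketch for `retrieve_timesteps` (variable names below are hypothetical).
# The plain `num_inference_steps` path always works; custom `timesteps` or `sigmas` are only
# accepted when the scheduler's `set_timesteps` signature supports them, otherwise the checks
# above raise a ValueError.
#
#     scheduler = FlowMatchEulerDiscreteScheduler()
#     timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=28, device="cpu")
#     # num_steps == len(timesteps); passing both `timesteps` and `sigmas` raises a ValueError.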
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
||||
with an additional projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
||||
as its dimension.
|
||||
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
text_encoder_3 ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: SD3Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer") and self.transformer is not None
|
||||
else 128
|
||||
)
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if self.text_encoder_3 is None:
|
||||
return torch.zeros(
|
||||
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer_3(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = self.text_encoder_3.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
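# Note on the `text_encoder_3 is None` branch above: loading the pipeline without the T5 encoder
# keeps this method usable and simply yields zero T5 embeddings. A minimal, hedged sketch of such
# a memory-saving load (checkpoint id taken from the example docstring) might look like:
#
#     pipe = StableDiffusion3Pipeline.from_pretrained(
#         "stabilityai/stable-diffusion-3-medium-diffusers",
#         text_encoder_3=None,
#         tokenizer_3=None,
#         torch_dtype=torch.float16,
#     )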
|
||||
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
clip_model_index: int = 0,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
tokenizer = clip_tokenizers[clip_model_index]
|
||||
text_encoder = clip_text_encoders[clip_model_index]
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=0,
|
||||
)
|
||||
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=1,
|
||||
)
|
||||
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
||||
|
||||
t5_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=0,
|
||||
)
|
||||
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=1,
|
||||
)
|
||||
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
||||
|
||||
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
|
||||
)
|
||||
|
||||
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
||||
negative_clip_prompt_embeds,
|
||||
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
||||
)
|
||||
|
||||
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
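# Shape sketch for `encode_prompt` (the figures assume the SD3-medium configuration: CLIP-L pooled/hidden
# size 768, CLIP-bigG pooled/hidden size 1280, T5/joint_attention_dim 4096, max token length 77):
#
#     clip_prompt_embeds:   (batch, 77, 768 + 1280) -> zero-padded on the feature dim to 4096
#     t5_prompt_embed:      (batch, 77, 4096)
#     prompt_embeds:        (batch, 154, 4096)   # concatenation along the sequence dimension
#     pooled_prompt_embeds: (batch, 2048)        # pooled CLIP-L and CLIP-bigG outputs concatenated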
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
return latents
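# Example of the latent shape produced above, assuming the SD3-medium configuration
# (16 transformer input channels, vae_scale_factor of 8): a single 1024x1024 request gives
#
#     shape = (1, 16, 1024 // 8, 1024 // 8)  # i.e. (1, 16, 128, 128)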
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
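# Equivalently, the guided prediction computed in `__call__` below is
#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# so `guidance_scale <= 1` skips the unconditional branch (the latents are not duplicated).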
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
|
||||
will be used instead
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
|
||||
will be used instead
|
||||
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
||||
num_inference_steps (`int`, *optional*, defaults to 28):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
height = height or self.default_sample_size * self.vae_scale_factor
|
||||
width = width or self.default_sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
height,
|
||||
width,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
joint_attention_kwargs=self.joint_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusion3PipelineOutput(images=image)
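# Hedged usage sketch for the step-end callback machinery above (function and variable names are
# illustrative): the callback receives (pipe, step_index, timestep, callback_kwargs) and must return
# a dict; keys listed in `callback_on_step_end_tensor_inputs` overwrite the corresponding tensors.
#
#     def inspect_latents(pipe, step, timestep, callback_kwargs):
#         latents = callback_kwargs["latents"]
#         print(step, timestep.item(), latents.std().item())
#         return {"latents": latents}
#
#     image = pipe(
#         "A cat holding a sign that says hello world",
#         callback_on_step_end=inspect_latents,
#         callback_on_step_end_tensor_inputs=["latents"],
#     ).images[0]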
|
||||
@@ -0,0 +1,923 @@
|
||||
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import (
|
||||
CLIPTextModelWithProjection,
|
||||
CLIPTokenizer,
|
||||
T5EncoderModel,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
|
||||
from ...image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.transformers import SD3Transformer2DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .pipeline_output import StableDiffusion3PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
|
||||
>>> from diffusers import AutoPipelineForImage2Image
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> device = "cuda"
|
||||
>>> model_id_or_path = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
>>> pipe = AutoPipelineForImage2Image.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to(device)
|
||||
|
||||
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
>>> init_image = load_image(url).resize((512, 512))
|
||||
|
||||
>>> prompt = "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k"
|
||||
|
||||
>>> images = pipe(prompt=prompt, image=init_image, strength=0.95, guidance_scale=7.5).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
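# Illustrative only (the image tensor below is hypothetical): for an `AutoencoderKL`, the output of
# `vae.encode` exposes `latent_dist`, so
#
#     posterior = vae.encode(image_tensor)
#     latents = retrieve_latents(posterior, generator=torch.Generator().manual_seed(0))
#
# draws a sample from the posterior, while sample_mode="argmax" would return its mode instead.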
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Args:
|
||||
transformer ([`SD3Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant,
|
||||
with an additional projection layer that is initialized with a diagonal matrix with the `hidden_size`
|
||||
as its dimension.
|
||||
text_encoder_2 ([`CLIPTextModelWithProjection`]):
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
|
||||
specifically the
|
||||
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
|
||||
variant.
|
||||
text_encoder_3 ([`T5EncoderModel`]):
|
||||
Frozen text-encoder. Stable Diffusion 3 uses
|
||||
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
|
||||
[t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_2 (`CLIPTokenizer`):
|
||||
Second Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
tokenizer_3 (`T5TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: SD3Transformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder_2: CLIPTextModelWithProjection,
|
||||
tokenizer_2: CLIPTokenizer,
|
||||
text_encoder_3: T5EncoderModel,
|
||||
tokenizer_3: T5TokenizerFast,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
text_encoder_2=text_encoder_2,
|
||||
text_encoder_3=text_encoder_3,
|
||||
tokenizer=tokenizer,
|
||||
tokenizer_2=tokenizer_2,
|
||||
tokenizer_3=tokenizer_3,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, vae_latent_channels=self.vae.config.latent_channels
|
||||
)
|
||||
self.tokenizer_max_length = self.tokenizer.model_max_length
|
||||
self.default_sample_size = self.transformer.config.sample_size
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
if self.text_encoder_3 is None:
|
||||
return torch.zeros(
|
||||
(batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer_3(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = self.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer_3.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder_3(text_input_ids.to(device))[0]
|
||||
|
||||
dtype = self.text_encoder_3.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds
|
||||
def _get_clip_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
clip_model_index: int = 0,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
clip_tokenizers = [self.tokenizer, self.tokenizer_2]
|
||||
clip_text_encoders = [self.text_encoder, self.text_encoder_2]
|
||||
|
||||
tokenizer = clip_tokenizers[clip_model_index]
|
||||
text_encoder = clip_text_encoders[clip_model_index]
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
batch_size = len(prompt)
|
||||
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
||||
)
|
||||
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
else:
|
||||
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
prompt_2: Union[str, List[str]],
|
||||
prompt_3: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
clip_skip: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to the `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is
|
||||
used in all text-encoders
|
||||
device (`torch.device`):
|
||||
torch device
|
||||
num_images_per_prompt (`int`):
|
||||
number of images that should be generated per prompt
|
||||
do_classifier_free_guidance (`bool`):
|
||||
whether to use classifier free guidance or not
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used in all the text-encoders.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
clip_skip (`int`, *optional*):
|
||||
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
||||
the output of the pre-final layer will be used for computing the prompt embeddings.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_2 = prompt_2 or prompt
|
||||
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
||||
|
||||
prompt_3 = prompt_3 or prompt
|
||||
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
|
||||
|
||||
prompt_embed, pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=0,
|
||||
)
|
||||
prompt_2_embed, pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
prompt=prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_model_index=1,
|
||||
)
|
||||
clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
|
||||
|
||||
t5_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=prompt_3,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
)
|
||||
|
||||
clip_prompt_embeds = torch.nn.functional.pad(
|
||||
clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
|
||||
pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
negative_prompt_3 = negative_prompt_3 or negative_prompt
|
||||
|
||||
# normalize str to list
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_2 = (
|
||||
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
|
||||
)
|
||||
negative_prompt_3 = (
|
||||
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
|
||||
)
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=0,
|
||||
)
|
||||
negative_prompt_2_embed, negative_pooled_prompt_2_embed = self._get_clip_prompt_embeds(
|
||||
negative_prompt_2,
|
||||
device=device,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
clip_skip=None,
|
||||
clip_model_index=1,
|
||||
)
|
||||
negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
|
||||
|
||||
t5_negative_prompt_embed = self._get_t5_prompt_embeds(
|
||||
prompt=negative_prompt_3, num_images_per_prompt=num_images_per_prompt, device=device
|
||||
)
|
||||
|
||||
negative_clip_prompt_embeds = torch.nn.functional.pad(
|
||||
negative_clip_prompt_embeds,
|
||||
(0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
|
||||
)
|
||||
|
||||
negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
|
||||
negative_pooled_prompt_embeds = torch.cat(
|
||||
[negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
|
||||
)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
|
||||
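The padding/concatenation logic in `encode_prompt` above is easiest to see on dummy tensors. The following standalone sketch mirrors those steps; the hidden sizes (768/1280/4096) and sequence lengths are illustrative assumptions, not values read from this diff.

import torch

# Stand-ins for the three encoder outputs; shapes are assumptions for illustration only.
batch, clip_seq, t5_seq = 2, 77, 256
clip_l_dim, clip_g_dim, t5_dim = 768, 1280, 4096
clip_l_embed = torch.randn(batch, clip_seq, clip_l_dim)  # would come from `_get_clip_prompt_embeds(..., clip_model_index=0)`
clip_g_embed = torch.randn(batch, clip_seq, clip_g_dim)  # would come from `clip_model_index=1`
t5_embed = torch.randn(batch, t5_seq, t5_dim)            # would come from `_get_t5_prompt_embeds`

# 1. concatenate the two CLIP embeddings along the feature axis
clip_embeds = torch.cat([clip_l_embed, clip_g_embed], dim=-1)                              # (2, 77, 2048)
# 2. zero-pad the feature axis up to the T5 width so both sequences share one feature size
clip_embeds = torch.nn.functional.pad(clip_embeds, (0, t5_dim - clip_embeds.shape[-1]))    # (2, 77, 4096)
# 3. stack CLIP and T5 tokens along the sequence axis to form the joint text context
prompt_embeds = torch.cat([clip_embeds, t5_embed], dim=-2)                                 # (2, 333, 4096)
print(prompt_embeds.shape)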
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
strength,
|
||||
negative_prompt=None,
|
||||
negative_prompt_2=None,
|
||||
negative_prompt_3=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
pooled_prompt_embeds=None,
|
||||
negative_pooled_prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_2 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt_3 is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt_3`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
||||
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
||||
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
|
||||
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
elif negative_prompt_3 is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
||||
)
|
||||
|
||||
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
||||
)
|
||||
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
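For intuition, here is a small worked example of how `strength` trims the schedule in `get_timesteps` above, using assumed values (50 steps, strength 0.6, a first-order scheduler with `order == 1`): denoising skips the first 20 timesteps and runs only the last 30.

# Hypothetical numbers, mirroring the arithmetic of get_timesteps() above.
num_inference_steps = 50
strength = 0.6
order = 1  # FlowMatchEulerDiscreteScheduler.order

init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
remaining_steps = num_inference_steps - t_start                           # 30
print(t_start * order, remaining_steps)  # start index into scheduler.timesteps, steps actually run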
|
||||
def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
|
||||
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
||||
raise ValueError(
|
||||
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
||||
)
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
if image.shape[1] == self.vae.config.latent_channels:
|
||||
init_latents = image
|
||||
|
||||
else:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
elif isinstance(generator, list):
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(batch_size)
|
||||
]
|
||||
init_latents = torch.cat(init_latents, dim=0)
|
||||
else:
|
||||
init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
init_latents = (init_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
|
||||
# expand init_latents for batch_size
|
||||
additional_image_per_prompt = batch_size // init_latents.shape[0]
|
||||
init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
else:
|
||||
init_latents = torch.cat([init_latents], dim=0)
|
||||
|
||||
shape = init_latents.shape
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
# get latents
|
||||
init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
|
||||
latents = init_latents.to(device=device, dtype=dtype)
|
||||
|
||||
return latents
|
||||
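A standalone sketch of the latent normalization used in `prepare_latents` above and undone again before decoding. The `shift_factor`/`scaling_factor` values match the dummy VAE config used in the tests later in this diff; the tensor shape is an arbitrary illustration.

import torch

shift_factor, scaling_factor = 0.0609, 1.5035

raw_latents = torch.randn(1, 4, 64, 64)                         # pretend output of vae.encode(image)
init_latents = (raw_latents - shift_factor) * scaling_factor    # normalization applied before noising
recovered = init_latents / scaling_factor + shift_factor        # inverse applied before vae.decode(...)
assert torch.allclose(recovered, raw_latents, atol=1e-6)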
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def clip_skip(self):
|
||||
return self._clip_skip
|
||||
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
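As a quick illustration of the guidance formula referenced in the comment above, here is a tiny sketch with made-up numbers (not pipeline output): the weighted difference pushes the prediction toward the text-conditional direction.

import torch

guidance_scale = 7.0
noise_pred_uncond = torch.tensor([0.10, -0.20])
noise_pred_text = torch.tensor([0.30, 0.05])

# uncond + w * (text - uncond), as in the denoising loop below
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(noise_pred)  # tensor([1.5000, 1.5500])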
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
strength: float = 0.6,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 7.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
||||
negative_prompt_3: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clip_skip: Optional[int] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
|
||||
prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
will be used instead.
|
||||
prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
will be used instead.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point.
strength (`float`, *optional*, defaults to 0.6):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. When `strength` is 1, the denoising
process runs for the full number of iterations specified in `num_inference_steps`.
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
negative_prompt_2 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
|
||||
`text_encoder_2`. If not defined, `negative_prompt` is used instead
|
||||
negative_prompt_3 (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
|
||||
`text_encoder_3`. If not defined, `negative_prompt` is used instead
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
||||
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
||||
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
|
||||
input argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
|
||||
of a plain tuple.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
prompt_2,
|
||||
prompt_3,
|
||||
strength,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._clip_skip = clip_skip
|
||||
self._interrupt = False
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
prompt_2=prompt_2,
|
||||
prompt_3=prompt_3,
|
||||
negative_prompt=negative_prompt,
|
||||
negative_prompt_2=negative_prompt_2,
|
||||
negative_prompt_3=negative_prompt_3,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
device=device,
|
||||
clip_skip=self.clip_skip,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
|
||||
|
||||
# 3. Preprocess image
|
||||
image = self.image_processor.preprocess(image)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
|
||||
|
||||
# 5. Prepare latent variables
|
||||
if latents is None:
|
||||
latents = self.prepare_latents(
|
||||
image,
|
||||
latent_timestep,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latent_model_input.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
pooled_projections=pooled_prompt_embeds,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
negative_pooled_prompt_embeds = callback_outputs.pop(
|
||||
"negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
|
||||
)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return StableDiffusion3PipelineOutput(images=image)
|
||||
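A hedged end-to-end usage sketch for the image-to-image pipeline defined above. The checkpoint id matches the one used in the slow tests later in this diff (access to the weights is gated, so you must accept the license and be logged in); the input image path and prompt are placeholders you would replace.

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # keeps VRAM use down by offloading idle components

init_image = load_image("path/to/your/image.png")  # placeholder path
image = pipe(
    prompt="A photo of a cat wearing a wizard hat",  # placeholder prompt
    image=init_image,
    strength=0.6,             # how much of the schedule to re-run on top of the input image
    num_inference_steps=50,
    guidance_scale=7.0,
).images[0]
image.save("sd3_img2img.png")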
@@ -56,6 +56,7 @@ else:
|
||||
_import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"]
|
||||
_import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
|
||||
_import_structure["scheduling_euler_discrete"] = ["EulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
|
||||
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
|
||||
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
|
||||
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
||||
@@ -151,6 +152,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_edm_euler import EDMEulerScheduler
|
||||
from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
|
||||
from .scheduling_euler_discrete import EulerDiscreteScheduler
|
||||
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
|
||||
from .scheduling_heun_discrete import HeunDiscreteScheduler
|
||||
from .scheduling_ipndm import IPNDMScheduler
|
||||
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
||||
|
||||
src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py (new file, 287 lines)
@@ -0,0 +1,287 @@
|
||||
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput, logging
|
||||
from ..utils.torch_utils import randn_tensor
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's `step` function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
|
||||
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.FloatTensor
|
||||
|
||||
|
||||
class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Euler scheduler.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
The number of diffusion steps to train the model.
|
||||
timestep_spacing (`str`, defaults to `"linspace"`):
|
||||
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
|
||||
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
|
||||
shift (`float`, defaults to 1.0):
|
||||
The shift value for the timestep schedule.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
shift: float = 1.0,
|
||||
):
|
||||
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
|
||||
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
|
||||
|
||||
sigmas = timesteps / num_train_timesteps
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
|
||||
self.timesteps = sigmas * num_train_timesteps
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
self.sigma_min = self.sigmas[-1].item()
|
||||
self.sigma_max = self.sigmas[0].item()
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
"""
|
||||
The index counter for the current timestep. It increases by 1 after each scheduler step.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def scale_noise(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
noise: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Forward process in flow-matching.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The input sample.
|
||||
timestep (`float` or `torch.FloatTensor`):
The current timestep in the diffusion chain.
noise (`torch.FloatTensor`, *optional*):
The noise tensor to mix into the sample.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor`:
|
||||
A scaled input sample.
|
||||
"""
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
sample = sigma * noise + (1.0 - sigma) * sample
|
||||
|
||||
return sample
|
||||
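A standalone illustration of the interpolation performed by `scale_noise` above: the noisy sample is a convex mix of noise and clean sample weighted by sigma. The sigma value and tensors below are assumptions chosen to keep the arithmetic obvious.

import torch

sigma = 0.25                      # assumed point on the schedule
sample = torch.full((2, 2), 4.0)  # pretend clean latents
noise = torch.zeros(2, 2)         # pretend Gaussian noise (zeros make the result easy to read)

noisy = sigma * noise + (1.0 - sigma) * sample
print(noisy)  # all 3.0: at sigma=0.25 the sample keeps 75% of its clean signal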
|
||||
def _sigma_to_t(self, sigma):
|
||||
return sigma * self.config.num_train_timesteps
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
timesteps = np.linspace(
|
||||
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
|
||||
)
|
||||
|
||||
sigmas = timesteps / self.config.num_train_timesteps
|
||||
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
|
||||
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
|
||||
|
||||
timesteps = sigmas * self.config.num_train_timesteps
|
||||
self.timesteps = timesteps.to(device=device)
|
||||
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
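To see what the `shift` transform in `set_timesteps` does to the sigma schedule, here is a small sketch on a plain array; `shift=3.0` is an assumed value for demonstration, while `shift=1.0` (the config default) leaves the schedule unchanged.

import numpy as np

sigmas = np.array([0.9, 0.7, 0.5, 0.3, 0.1])
shift = 3.0
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)
print(shifted)  # every sigma moves upward: a larger shift spends more of the trajectory at high noise levels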
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: Union[float, torch.FloatTensor],
|
||||
sample: torch.FloatTensor,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
s_churn (`float`):
|
||||
s_tmin (`float`):
|
||||
s_tmax (`float`):
|
||||
s_noise (`float`, defaults to 1.0):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# Upcast to avoid precision issues when computing prev_sample
|
||||
sample = sample.to(torch.float32)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
|
||||
|
||||
noise = randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
if gamma > 0:
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
# NOTE: "original_sample" should not be an expected prediction_type but is left in for
|
||||
# backwards compatibility
|
||||
|
||||
# if self.config.prediction_type == "vector_field":
|
||||
|
||||
denoised = sample - model_output * sigma
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - denoised) / sigma_hat
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
prev_sample = sample + derivative * dt
|
||||
# Cast sample back to model compatible dtype
|
||||
prev_sample = prev_sample.to(model_output.dtype)
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample,)
|
||||
|
||||
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps
|
||||
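A minimal usage sketch of the scheduler class defined above, assuming it is available in your diffusers install. The "model" here is just random velocities standing in for the SD3 transformer, so the output is meaningless; the point is the set_timesteps/step loop.

import torch
from diffusers import FlowMatchEulerDiscreteScheduler

scheduler = FlowMatchEulerDiscreteScheduler(shift=1.0)
scheduler.set_timesteps(num_inference_steps=4)

sample = torch.randn(1, 4, 8, 8)  # pretend latents
for t in scheduler.timesteps:
    model_output = torch.randn_like(sample)  # would be the transformer's predicted velocity
    sample = scheduler.step(model_output, t, sample, return_dict=False)[0]
print(sample.shape)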
@@ -242,6 +242,21 @@ class PriorTransformer(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class SD3Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class T2IAdapter(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1005,6 +1020,21 @@ class EulerDiscreteScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class FlowMatchEulerDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HeunDiscreteScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -902,6 +902,36 @@ class StableCascadePriorPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3Img2ImgPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusion3Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class StableDiffusionAdapterPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
tests/lora/test_lora_layers_sd3.py (new file, 207 lines)
@@ -0,0 +1,207 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
SD3Transformer2DModel,
|
||||
StableDiffusion3Pipeline,
|
||||
)
|
||||
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import check_if_lora_correctly_set # noqa: E402
|
||||
|
||||
|
||||
@require_peft_backend
|
||||
class SD3LoRATests(unittest.TestCase):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SD3Transformer2DModel(
|
||||
sample_size=32,
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=1,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=4,
|
||||
caption_projection_dim=32,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=64,
|
||||
out_channels=4,
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=4,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"text_encoder_3": text_encoder_3,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"tokenizer_3": tokenizer_3,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def get_lora_config_for_transformer(self):
|
||||
lora_config = LoraConfig(
|
||||
r=4,
|
||||
lora_alpha=4,
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
|
||||
init_lora_weights=False,
|
||||
use_dora=False,
|
||||
)
|
||||
return lora_config
|
||||
|
||||
def test_simple_inference_with_transformer_lora_save_load(self):
|
||||
components = self.get_dummy_components()
|
||||
transformer_config = self.get_lora_config_for_transformer()
|
||||
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
pipe.transformer.add_adapter(transformer_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
images_lora = pipe(**inputs).images
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
|
||||
|
||||
self.pipeline_class.save_lora_weights(
|
||||
save_directory=tmpdirname,
|
||||
transformer_lora_layers=transformer_state_dict,
|
||||
)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
|
||||
pipe.unload_lora_weights()
|
||||
|
||||
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
images_lora_from_pretrained = pipe(**inputs).images
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
self.assertTrue(
|
||||
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
|
||||
"Loading from saved checkpoints should give same results.",
|
||||
)
|
||||
|
||||
def test_simple_inference_with_transformer_lora_and_scale(self):
|
||||
components = self.get_dummy_components()
|
||||
transformer_lora_config = self.get_lora_config_for_transformer()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe = pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_no_lora = pipe(**inputs).images
|
||||
|
||||
pipe.transformer.add_adapter(transformer_lora_config)
|
||||
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_lora = pipe(**inputs).images
|
||||
self.assertTrue(
|
||||
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
|
||||
self.assertTrue(
|
||||
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
|
||||
"Lora + scale should change the output",
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
|
||||
self.assertTrue(
|
||||
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
|
||||
"Lora + 0 scale should lead to same result as no LoRA",
|
||||
)
|
||||
tests/models/transformers/test_models_transformer_sd3.py (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import SD3Transformer2DModel
|
||||
from diffusers.utils.testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = SD3Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = width = embedding_dim = 32
|
||||
pooled_embedding_dim = embedding_dim * 2
|
||||
sequence_length = 154
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 32,
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 4,
|
||||
"caption_projection_dim": 32,
|
||||
"joint_attention_dim": 32,
|
||||
"pooled_projection_dim": 64,
|
||||
"out_channels": 4,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
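For reference, a hedged sketch of what the test scaffolding above exercises: building the same tiny SD3Transformer2DModel configuration and running a single forward pass with the dummy input shapes. The `.sample` attribute is the standard transformer output field in diffusers; everything else mirrors the test's own values.

import torch
from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32, patch_size=1, in_channels=4, num_layers=1, attention_head_dim=8,
    num_attention_heads=4, caption_projection_dim=32, joint_attention_dim=32,
    pooled_projection_dim=64, out_channels=4,
)
out = model(
    hidden_states=torch.randn(2, 4, 32, 32),          # (batch, channels, height, width)
    encoder_hidden_states=torch.randn(2, 154, 32),    # joint text context
    pooled_projections=torch.randn(2, 64),            # pooled CLIP embeddings
    timestep=torch.randint(0, 1000, (2,)),
).sample
print(out.shape)  # per-sample shape should match the (4, 32, 32) output_shape used by the test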
tests/pipelines/stable_diffusion_3/__init__.py (new file, 0 lines)
@@ -0,0 +1,271 @@
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
|
||||
from diffusers.utils.testing_utils import (
|
||||
numpy_cosine_similarity_distance,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = StableDiffusion3Pipeline
|
||||
params = frozenset(
|
||||
[
|
||||
"prompt",
|
||||
"height",
|
||||
"width",
|
||||
"guidance_scale",
|
||||
"negative_prompt",
|
||||
"prompt_embeds",
|
||||
"negative_prompt_embeds",
|
||||
]
|
||||
)
|
||||
batch_params = frozenset(["prompt", "negative_prompt"])
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = SD3Transformer2DModel(
|
||||
sample_size=32,
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
num_layers=1,
|
||||
attention_head_dim=8,
|
||||
num_attention_heads=4,
|
||||
caption_projection_dim=32,
|
||||
joint_attention_dim=32,
|
||||
pooled_projection_dim=64,
|
||||
out_channels=4,
|
||||
)
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
hidden_act="gelu",
|
||||
projection_dim=32,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
torch.manual_seed(0)
|
||||
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
|
||||
|
||||
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=4,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0609,
|
||||
scaling_factor=1.5035,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
return {
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"text_encoder_2": text_encoder_2,
|
||||
"text_encoder_3": text_encoder_3,
|
||||
"tokenizer": tokenizer,
|
||||
"tokenizer_2": tokenizer_2,
|
||||
"tokenizer_3": tokenizer_3,
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "A painting of a squirrel eating a burger",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 5.0,
|
||||
"output_type": "np",
|
||||
}
|
||||
return inputs
|
||||
|
||||
def test_stable_diffusion_3_different_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["prompt_2"] = "a different prompt"
|
||||
inputs["prompt_3"] = "another different prompt"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_different_negative_prompts(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
output_same_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
inputs["negative_prompt_2"] = "deformed"
|
||||
inputs["negative_prompt_3"] = "blurry"
|
||||
output_different_prompts = pipe(**inputs).images[0]
|
||||
|
||||
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
|
||||
|
||||
# Outputs should be different here
|
||||
assert max_diff > 1e-2
|
||||
|
||||
def test_stable_diffusion_3_prompt_embeds(self):
|
||||
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
|
||||
output_with_prompt = pipe(**inputs).images[0]
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
prompt = inputs.pop("prompt")
|
||||
|
||||
do_classifier_free_guidance = inputs["guidance_scale"] > 1
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = pipe.encode_prompt(
|
||||
prompt,
|
||||
prompt_2=None,
|
||||
prompt_3=None,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
device=torch_device,
|
||||
)
|
||||
output_with_embeds = pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
**inputs,
|
||||
).images[0]
|
||||
|
||||
max_diff = np.abs(output_with_prompt - output_with_embeds).max()
|
||||
assert max_diff < 1e-4
|
||||
|
||||
    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        assert np.allclose(
            original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3
        ), "Fusion of QKV projections shouldn't affect the outputs."
        assert np.allclose(
            image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3
        ), "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        assert np.allclose(
            original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
        ), "Original outputs should match when fused QKV projections are disabled."


@slow
@require_torch_gpu
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Pipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        return {
            "prompt": "A photo of a cat",
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
            "generator": generator,
        }

    def test_sd3_inference(self):
        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)

        image = pipe(**inputs).images[0]
        image_slice = image[0, :10, :10]
        expected_slice = np.array(
            [
                [0.36132812, 0.30004883, 0.25830078],
                [0.36669922, 0.31103516, 0.23754883],
                [0.34814453, 0.29248047, 0.23583984],
                [0.35791016, 0.30981445, 0.23999023],
                [0.36328125, 0.31274414, 0.2607422],
                [0.37304688, 0.32177734, 0.26171875],
                [0.3671875, 0.31933594, 0.25756836],
                [0.36035156, 0.31103516, 0.2578125],
                [0.3857422, 0.33789062, 0.27563477],
                [0.3701172, 0.31982422, 0.265625],
            ],
            dtype=np.float32,
        )

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert max_diff < 1e-4

@@ -0,0 +1,258 @@
import gc
import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
    StableDiffusion3Img2ImgPipeline,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    floats_tensor,
    numpy_cosine_similarity_distance,
    require_torch_gpu,
    slow,
    torch_device,
)

from ..pipeline_params import (
    IMAGE_TO_IMAGE_IMAGE_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
    TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
)
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin


class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unittest.TestCase, PipelineTesterMixin):
    pipeline_class = StableDiffusion3Img2ImgPipeline
    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
    required_optional_params = PipelineTesterMixin.required_optional_params
    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = SD3Transformer2DModel(
            sample_size=32,
            patch_size=1,
            in_channels=4,
            num_layers=1,
            attention_head_dim=8,
            num_attention_heads=4,
            joint_attention_dim=32,
            caption_projection_dim=32,
            pooled_projection_dim=64,
            out_channels=4,
        )
        clip_text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            hidden_act="gelu",
            projection_dim=32,
        )

        torch.manual_seed(0)
        text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)

        torch.manual_seed(0)
        text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)

        text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")

        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=4,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "text_encoder_3": text_encoder_3,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
            "tokenizer_3": tokenizer_3,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
        image = image / 2 + 0.5
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
            "strength": 0.8,
        }
        return inputs

    def test_stable_diffusion_3_img2img_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt_2"] = "a different prompt"
        inputs["prompt_3"] = "another different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # Outputs should be different here
        assert max_diff > 1e-2

    def test_stable_diffusion_3_img2img_different_negative_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["negative_prompt_2"] = "deformed"
        inputs["negative_prompt_3"] = "blurry"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # Outputs should be different here
        assert max_diff > 1e-2

    def test_stable_diffusion_3_img2img_prompt_embeds(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        output_with_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        prompt = inputs.pop("prompt")

        do_classifier_free_guidance = inputs["guidance_scale"] > 1
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt,
            prompt_2=None,
            prompt_3=None,
            do_classifier_free_guidance=do_classifier_free_guidance,
            device=torch_device,
        )
        output_with_embeds = pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            **inputs,
        ).images[0]

        max_diff = np.abs(output_with_prompt - output_with_embeds).max()
        assert max_diff < 1e-4

    def test_multi_vae(self):
        pass


@slow
@require_torch_gpu
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
    pipeline_class = StableDiffusion3Img2ImgPipeline
    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

    def setUp(self):
        super().setUp()
        gc.collect()
        torch.cuda.empty_cache()

    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_inputs(self, device, seed=0):
        init_image = load_image(
            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
            "/stable_diffusion_img2img/sketch-mountains-input.png"
        )
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        return {
            "prompt": "A photo of a cat",
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "output_type": "np",
            "generator": generator,
            "image": init_image,
        }

    def test_sd3_img2img_inference(self):
        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
        pipe.enable_model_cpu_offload()

        inputs = self.get_inputs(torch_device)

        image = pipe(**inputs).images[0]
        image_slice = image[0, :10, :10]
        expected_slice = np.array(
            [
                [0.50097656, 0.44726562, 0.40429688],
                [0.5048828, 0.45703125, 0.38110352],
                [0.4987793, 0.45141602, 0.38134766],
                [0.49682617, 0.45336914, 0.38354492],
                [0.49804688, 0.4555664, 0.39379883],
                [0.5083008, 0.4645996, 0.40039062],
                [0.50341797, 0.46240234, 0.39770508],
                [0.49926758, 0.4572754, 0.39575195],
                [0.50634766, 0.46435547, 0.39794922],
                [0.50341797, 0.4572754, 0.39746094],
            ],
            dtype=np.float32,
        )

        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

        assert max_diff < 1e-4, f"Outputs are not close enough, got {image_slice}"