-
-

-
"A photo of a banana-shaped couch in a living room"
+
+

+
A cute cat lounges on a leaf in a pool during a peaceful summer afternoon, in lofi art style, illustration.
-
-

-
"A vibrant yellow banana-shaped couch sits in a cozy living room, its curve cradling a pile of colorful cushions. on the wooden floor, a patterned rug adds a touch of eclectic charm, and a potted plant sits in the corner, reaching towards the sunlight filtering through the windows"
+
+

+
A cute cat lounges on a floating leaf in a sparkling pool during a peaceful summer afternoon. Clear reflections ripple across the water, with sunlight casting soft, smooth highlights. The illustration is detailed and polished, with elegant lines and harmonious colors, evoking a relaxing, serene, and whimsical lofi mood, anime-inspired and visually comforting.
-## Prompt enhancing with GPT2
-
-Prompt enhancing is a technique for quickly improving prompt quality without spending too much effort constructing one. It uses a model like GPT2 pretrained on Stable Diffusion text prompts to automatically enrich a prompt with additional important keywords to generate high-quality images.
-
-The technique works by curating a list of specific keywords and forcing the model to generate those words to enhance the original prompt. This way, your prompt can be "a cat" and GPT2 can enhance the prompt to "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic".
+Be specific and add context. Use photography terms like lens type, focal length, camera angles, and depth of field.
> [!TIP]
-> You should also use a [*offset noise*](https://www.crosslabs.org//blog/diffusion-with-offset-noise) LoRA to improve the contrast in bright and dark images and create better lighting overall. This [LoRA](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors) is available from [stabilityai/stable-diffusion-xl-base-1.0](https://hf.co/stabilityai/stable-diffusion-xl-base-1.0).
-
-Start by defining certain styles and a list of words (you can check out a more comprehensive list of [words](https://hf.co/LykosAI/GPT-Prompt-Expansion-Fooocus-v2/blob/main/positive.txt) and [styles](https://github.com/lllyasviel/Fooocus/tree/main/sdxl_styles) used by Fooocus) to enhance a prompt with.
-
-```py
-import torch
-from transformers import GenerationConfig, GPT2LMHeadModel, GPT2Tokenizer, LogitsProcessor, LogitsProcessorList
-from diffusers import StableDiffusionXLPipeline
-
-styles = {
- "cinematic": "cinematic film still of {prompt}, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
- "anime": "anime artwork of {prompt}, anime style, key visual, vibrant, studio anime, highly detailed",
- "photographic": "cinematic photo of {prompt}, 35mm photograph, film, professional, 4k, highly detailed",
- "comic": "comic of {prompt}, graphic illustration, comic art, graphic novel art, vibrant, highly detailed",
- "lineart": "line art drawing {prompt}, professional, sleek, modern, minimalist, graphic, line art, vector graphics",
- "pixelart": " pixel-art {prompt}, low-res, blocky, pixel art style, 8-bit graphics",
-}
-
-words = [
- "aesthetic", "astonishing", "beautiful", "breathtaking", "composition", "contrasted", "epic", "moody", "enhanced",
- "exceptional", "fascinating", "flawless", "glamorous", "glorious", "illumination", "impressive", "improved",
- "inspirational", "magnificent", "majestic", "hyperrealistic", "smooth", "sharp", "focus", "stunning", "detailed",
- "intricate", "dramatic", "high", "quality", "perfect", "light", "ultra", "highly", "radiant", "satisfying",
- "soothing", "sophisticated", "stylish", "sublime", "terrific", "touching", "timeless", "wonderful", "unbelievable",
- "elegant", "awesome", "amazing", "dynamic", "trendy",
-]
-```
-
-You may have noticed in the `words` list, there are certain words that can be paired together to create something more meaningful. For example, the words "high" and "quality" can be combined to create "high quality". Let's pair these words together and remove the words that can't be paired.
-
-```py
-word_pairs = ["highly detailed", "high quality", "enhanced quality", "perfect composition", "dynamic light"]
-
-def find_and_order_pairs(s, pairs):
- words = s.split()
- found_pairs = []
- for pair in pairs:
- pair_words = pair.split()
- if pair_words[0] in words and pair_words[1] in words:
- found_pairs.append(pair)
- words.remove(pair_words[0])
- words.remove(pair_words[1])
-
- for word in words[:]:
- for pair in pairs:
- if word in pair.split():
- words.remove(word)
- break
- ordered_pairs = ", ".join(found_pairs)
- remaining_s = ", ".join(words)
- return ordered_pairs, remaining_s
-```
-
-Next, implement a custom [`~transformers.LogitsProcessor`] class that assigns tokens in the `words` list a value of 0 and assigns tokens not in the `words` list a negative value so they aren't picked during generation. This way, generation is biased towards words in the `words` list. After a word from the list is used, it is also assigned a negative value so it isn't picked again.
-
-```py
-class CustomLogitsProcessor(LogitsProcessor):
- def __init__(self, bias):
- super().__init__()
- self.bias = bias
-
- def __call__(self, input_ids, scores):
- if len(input_ids.shape) == 2:
- last_token_id = input_ids[0, -1]
- self.bias[last_token_id] = -1e10
- return scores + self.bias
-
-word_ids = [tokenizer.encode(word, add_prefix_space=True)[0] for word in words]
-bias = torch.full((tokenizer.vocab_size,), -float("Inf")).to("cuda")
-bias[word_ids] = 0
-processor = CustomLogitsProcessor(bias)
-processor_list = LogitsProcessorList([processor])
-```
-
-Combine the prompt and the `cinematic` style prompt defined in the `styles` dictionary earlier.
-
-```py
-prompt = "a cat basking in the sun on a roof in Turkey"
-style = "cinematic"
-
-prompt = styles[style].format(prompt=prompt)
-prompt
-"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
-```
-
-Load a GPT2 tokenizer and model from the [Gustavosta/MagicPrompt-Stable-Diffusion](https://huggingface.co/Gustavosta/MagicPrompt-Stable-Diffusion) checkpoint (this specific checkpoint is trained to generate prompts) to enhance the prompt.
-
-```py
-tokenizer = GPT2Tokenizer.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion")
-model = GPT2LMHeadModel.from_pretrained("Gustavosta/MagicPrompt-Stable-Diffusion", torch_dtype=torch.float16).to(
- "cuda"
-)
-model.eval()
-
-inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-token_count = inputs["input_ids"].shape[1]
-max_new_tokens = 50 - token_count
-
-generation_config = GenerationConfig(
- penalty_alpha=0.7,
- top_k=50,
- eos_token_id=model.config.eos_token_id,
- pad_token_id=model.config.eos_token_id,
- pad_token=model.config.pad_token_id,
- do_sample=True,
-)
-
-with torch.no_grad():
- generated_ids = model.generate(
- input_ids=inputs["input_ids"],
- attention_mask=inputs["attention_mask"],
- max_new_tokens=max_new_tokens,
- generation_config=generation_config,
- logits_processor=proccesor_list,
- )
-```
-
-Then you can combine the input prompt and the generated prompt. Feel free to take a look at what the generated prompt (`generated_part`) is, the word pairs that were found (`pairs`), and the remaining words (`words`). This is all packed together in the `enhanced_prompt`.
-
-```py
-output_tokens = [tokenizer.decode(generated_id, skip_special_tokens=True) for generated_id in generated_ids]
-input_part, generated_part = output_tokens[0][: len(prompt)], output_tokens[0][len(prompt) :]
-pairs, words = find_and_order_pairs(generated_part, word_pairs)
-formatted_generated_part = pairs + ", " + words
-enhanced_prompt = input_part + ", " + formatted_generated_part
-enhanced_prompt
-["cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain quality sharp focus beautiful detailed intricate stunning amazing epic"]
-```
-
-Finally, load a pipeline and the offset noise LoRA with a *low weight* to generate an image with the enhanced prompt.
-
-```py
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "RunDiffusion/Juggernaut-XL-v9", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-pipeline.load_lora_weights(
- "stabilityai/stable-diffusion-xl-base-1.0",
- weight_name="sd_xl_offset_example-lora_1.0.safetensors",
- adapter_name="offset",
-)
-pipeline.set_adapters(["offset"], adapter_weights=[0.2])
-
-image = pipeline(
- enhanced_prompt,
- width=1152,
- height=896,
- guidance_scale=7.5,
- num_inference_steps=25,
-).images[0]
-image
-```
-
-
-
-

-
"a cat basking in the sun on a roof in Turkey"
-
-
-

-
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain"
-
-
+> Try a [prompt enhancer](https://huggingface.co/models?sort=downloads&search=prompt+enhancer) to help improve your prompt structure.
## Prompt weighting
-Prompt weighting provides a way to emphasize or de-emphasize certain parts of a prompt, allowing for more control over the generated image. A prompt can include several concepts, which gets turned into contextualized text embeddings. The embeddings are used by the model to condition its cross-attention layers to generate an image (read the Stable Diffusion [blog post](https://huggingface.co/blog/stable_diffusion) to learn more about how it works).
+Prompt weighting makes some words stronger and others weaker. It scales attention scores so you control how much influence each concept has.
-Prompt weighting works by increasing or decreasing the scale of the text embedding vector that corresponds to its concept in the prompt because you may not necessarily want the model to focus on all concepts equally. The easiest way to prepare the prompt embeddings is to use [Stable Diffusion Long Prompt Weighted Embedding](https://github.com/xhinker/sd_embed) (sd_embed). Once you have the prompt-weighted embeddings, you can pass them to any pipeline that has a [prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.prompt_embeds) (and optionally [negative_prompt_embeds](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__.negative_prompt_embeds)) parameter, such as [`StableDiffusionPipeline`], [`StableDiffusionControlNetPipeline`], and [`StableDiffusionXLPipeline`].
+Diffusers handles this through `prompt_embeds` and `pooled_prompt_embeds` arguments which take scaled text embedding vectors. Use the [sd_embed](https://github.com/xhinker/sd_embed) library to generate these embeddings. It also supports longer prompts.
-
-
-If your favorite pipeline doesn't have a `prompt_embeds` parameter, please open an [issue](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!
-
-
-
-This guide will show you how to weight your prompts with sd_embed.
-
-Before you begin, make sure you have the latest version of sd_embed installed:
-
-```bash
-pip install git+https://github.com/xhinker/sd_embed.git@main
-```
-
-For this example, let's use [`StableDiffusionXLPipeline`].
+> [!NOTE]
+> The sd_embed library only supports Stable Diffusion, Stable Diffusion XL, Stable Diffusion 3, Stable Cascade, and Flux. Prompt weighting doesn't necessarily help for newer models like Flux which already has very good prompt adherence.
```py
-from diffusers import StableDiffusionXLPipeline, UniPCMultistepScheduler
-import torch
-
-pipe = StableDiffusionXLPipeline.from_pretrained("Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16)
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.to("cuda")
+!uv pip install git+https://github.com/xhinker/sd_embed.git@main
```
-To upweight or downweight a concept, surround the text with parentheses. More parentheses applies a heavier weight on the text. You can also append a numerical multiplier to the text to indicate how much you want to increase or decrease its weights by.
+Format weighted text with numerical multipliers or parentheses. More parentheses mean stronger weighting.
| format | multiplier |
|---|---|
-| `(hippo)` | increase by 1.1x |
-| `((hippo))` | increase by 1.21x |
-| `(hippo:1.5)` | increase by 1.5x |
-| `(hippo:0.5)` | decrease by 4x |
+| `(cat)` | increase by 1.1x |
+| `((cat))` | increase by 1.21x |
+| `(cat:1.5)` | increase by 1.5x |
+| `(cat:0.5)` | decrease by 4x |
-Create a prompt and use a combination of parentheses and numerical multipliers to upweight various text.
+Create a weighted prompt and pass it to [get_weighted_text_embeddings_sdxl](https://github.com/xhinker/sd_embed/blob/4a47f71150a22942fa606fb741a1c971d95ba56f/src/sd_embed/embedding_funcs.py#L405) to generate embeddings.
+
+> [!TIP]
+> You could also pass negative prompts to `negative_prompt_embeds` and `negative_pooled_prompt_embeds`.
```py
+import torch
+from diffusers import DiffusionPipeline
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl
-prompt = """A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-```
-
-Use the `get_weighted_text_embeddings_sdxl` function to generate the prompt embeddings and the negative prompt embeddings. It'll also generated the pooled and negative pooled prompt embeddings since you're using the SDXL model.
-
-> [!TIP]
-> You can safely ignore the error message below about the token index length exceeding the models maximum sequence length. All your tokens will be used in the embedding process.
->
-> ```
-> Token indices sequence length is longer than the specified maximum sequence length for this model
-> ```
-
-```py
-(
- prompt_embeds,
- prompt_neg_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds
-) = get_weighted_text_embeddings_sdxl(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
+pipeline = DiffusionPipeline.from_pretrained(
+ "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.bfloat16, device_map="cuda"
)
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- pooled_prompt_embeds=pooled_prompt_embeds,
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
- num_inference_steps=30,
- height=1024,
- width=1024 + 512,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
+prompt = """
+A (cute cat:1.4) lounges on a (floating leaf:1.2) in a (sparkling pool:1.1) during a peaceful summer afternoon.
+Gentle ripples reflect pastel skies, while (sunlight:1.1) casts soft highlights. The illustration is smooth and polished
+with elegant, sketchy lines and subtle gradients, evoking a ((whimsical, nostalgic, dreamy lofi atmosphere:2.0)),
+(anime-inspired:1.6), calming, comforting, and visually serene.
+"""
+
+prompt_embeds, _, pooled_prompt_embeds, *_ = get_weighted_text_embeddings_sdxl(pipeline, prompt=prompt)
+```
+
+Pass the embeddings to `prompt_embeds` and `pooled_prompt_embeds` to generate your image.
+
+```py
+image = pipeline(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds).images[0]
```
-

+
-> [!TIP]
-> Refer to the [sd_embed](https://github.com/xhinker/sd_embed) repository for additional details about long prompt weighting for FLUX.1, Stable Cascade, and Stable Diffusion 1.5.
-
-### Textual inversion
-
-[Textual inversion](../training/text_inversion) is a technique for learning a specific concept from some images which you can use to generate new images conditioned on that concept.
-
-Create a pipeline and use the [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] function to load the textual inversion embeddings (feel free to browse the [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer) for 100+ trained concepts):
-
-```py
-import torch
-from diffusers import StableDiffusionPipeline
-
-pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5",
- torch_dtype=torch.float16,
-).to("cuda")
-pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
-```
-
-Add the `
` text to the prompt to trigger the textual inversion.
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """ A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-```
-
-Use the `get_weighted_text_embeddings_sd15` function to generate the prompt embeddings and the negative prompt embeddings.
-
-```py
-(
- prompt_embeds,
- prompt_neg_embeds,
-) = get_weighted_text_embeddings_sd15(
- pipe,
- prompt=prompt,
- neg_prompt=neg_prompt
-)
-
-image = pipe(
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=prompt_neg_embeds,
- height=768,
- width=896,
- guidance_scale=4.0,
- generator=torch.Generator("cuda").manual_seed(2)
-).images[0]
-image
-```
-
-
-

-
-
-### DreamBooth
-
-[DreamBooth](../training/dreambooth) is a technique for generating contextualized images of a subject given just a few images of the subject to train on. It is similar to textual inversion, but DreamBooth trains the full model whereas textual inversion only fine-tunes the text embeddings. This means you should use [`~DiffusionPipeline.from_pretrained`] to load the DreamBooth model (feel free to browse the [Stable Diffusion Dreambooth Concepts Library](https://huggingface.co/sd-dreambooth-library) for 100+ trained models):
-
-```py
-import torch
-from diffusers import DiffusionPipeline, UniPCMultistepScheduler
-
-pipe = DiffusionPipeline.from_pretrained("sd-dreambooth-library/dndcoverart-v1", torch_dtype=torch.float16).to("cuda")
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-```
-
-Depending on the model you use, you'll need to incorporate the model's unique identifier into your prompt. For example, the `dndcoverart-v1` model uses the identifier `dndcoverart`:
-
-```py
-from sd_embed.embedding_funcs import get_weighted_text_embeddings_sd15
-
-prompt = """dndcoverart of A whimsical and creative image depicting a hybrid creature that is a mix of a waffle and a hippopotamus.
-This imaginative creature features the distinctive, bulky body of a hippo,
-but with a texture and appearance resembling a golden-brown, crispy waffle.
-The creature might have elements like waffle squares across its skin and a syrup-like sheen.
-It's set in a surreal environment that playfully combines a natural water habitat of a hippo with elements of a breakfast table setting,
-possibly including oversized utensils or plates in the background.
-The image should evoke a sense of playful absurdity and culinary fantasy.
-"""
-
-neg_prompt = """\
-skin spots,acnes,skin blemishes,age spot,(ugly:1.2),(duplicate:1.2),(morbid:1.21),(mutilated:1.2),\
-(tranny:1.2),mutated hands,(poorly drawn hands:1.5),blurry,(bad anatomy:1.2),(bad proportions:1.3),\
-extra limbs,(disfigured:1.2),(missing arms:1.2),(extra legs:1.2),(fused fingers:1.5),\
-(too many fingers:1.5),(unclear eyes:1.2),lowers,bad hands,missing fingers,extra digit,\
-bad hands,missing fingers,(extra arms and legs),(worst quality:2),(low quality:2),\
-(normal quality:2),lowres,((monochrome)),((grayscale))
-"""
-
-(
- prompt_embeds
- , prompt_neg_embeds
-) = get_weighted_text_embeddings_sd15(
- pipe
- , prompt = prompt
- , neg_prompt = neg_prompt
-)
-```
-
-
-

-
+Prompt weighting works with [Textual inversion](./textual_inversion_inference) and [DreamBooth](./dreambooth) adapters too.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md
index 15a7e8dc7c..e34727b5da 100644
--- a/docs/source/en/using-diffusers/write_own_pipeline.md
+++ b/docs/source/en/using-diffusers/write_own_pipeline.md
@@ -110,11 +110,8 @@ Stable Diffusion is a text-to-image *latent diffusion* model. It is called a lat
As you can see, this is already more complex than the DDPM pipeline which only contains a UNet model. The Stable Diffusion model has three separate pretrained models.
-
-
-💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
-
-
+> [!TIP]
+> 💡 Read the [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) blog for more details about how the VAE, UNet, and text encoder models work.
Now that you know what you need for the Stable Diffusion pipeline, load all these components with the [`~ModelMixin.from_pretrained`] method. You can find them in the pretrained [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint, and each component is stored in a separate subfolder:
@@ -155,11 +152,8 @@ To speed up inference, move the models to a GPU since, unlike the scheduler, the
The next step is to tokenize the text to generate embeddings. The text is used to condition the UNet model and steer the diffusion process towards something that resembles the input prompt.
-
-
-💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.
-
-
+> [!TIP]
+> 💡 The `guidance_scale` parameter determines how much weight should be given to the prompt when generating an image.
Feel free to choose any prompt you like if you want to generate something else!
@@ -202,15 +196,12 @@ Let's concatenate the conditional and unconditional embeddings into a batch to a
Next, generate some initial random noise as a starting point for the diffusion process. This is the latent representation of the image, and it'll be gradually denoised. At this point, the `latent` image is smaller than the final image size but that's okay though because the model will transform it into the final 512x512 image dimensions later.
-
-
-💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:
-
-```py
-2 ** (len(vae.config.block_out_channels) - 1) == 8
-```
-
-
+> [!TIP]
+> 💡 The height and width are divided by 8 because the `vae` model has 3 down-sampling layers. You can check by running the following:
+>
+> ```py
+> 2 ** (len(vae.config.block_out_channels) - 1) == 8
+> ```
```py
>>> latents = torch.randn(
@@ -289,5 +280,5 @@ This is really what 🧨 Diffusers is designed for: to make it intuitive and eas
For your next steps, feel free to:
-* Learn how to [build and contribute a pipeline](../using-diffusers/contribute_pipeline) to 🧨 Diffusers. We can't wait and see what you'll come up with!
+* Learn how to [build and contribute a pipeline](../conceptual/contribution) to 🧨 Diffusers. We can't wait and see what you'll come up with!
* Explore [existing pipelines](../api/pipelines/overview) in the library, and see if you can deconstruct and build a pipeline from scratch using the models and schedulers separately.
diff --git a/docs/source/ja/installation.md b/docs/source/ja/installation.md
index 97d60528c4..fd6f4eda0f 100644
--- a/docs/source/ja/installation.md
+++ b/docs/source/ja/installation.md
@@ -108,11 +108,8 @@ pip install -e ".[flax]"
Python は通常のライブラリパスに加えて、クローンしたフォルダの中を探すようになります。
例えば、Python パッケージが通常 `~/anaconda3/envs/main/lib/python3.10/site-packages/` にインストールされている場合、Python はクローンした `~/diffusers/` フォルダも同様に参照します。
-
-
-ライブラリを使い続けたい場合は、`diffusers`フォルダを残しておく必要があります。
-
-
+> [!WARNING]
+> ライブラリを使い続けたい場合は、`diffusers`フォルダを残しておく必要があります。
これで、以下のコマンドで簡単にクローンを最新版の🤗 Diffusersにアップデートできます:
diff --git a/docs/source/ja/quicktour.md b/docs/source/ja/quicktour.md
index 03b340b352..ce88aaf7b5 100644
--- a/docs/source/ja/quicktour.md
+++ b/docs/source/ja/quicktour.md
@@ -24,11 +24,8 @@ specific language governing permissions and limitations under the License.
この案内では、[`DiffusionPipeline`]を生成に使用する方法を紹介し、モデルとスケジューラを組み合わせて[`DiffusionPipeline`]の内部で起こっていることを再現する方法を説明します。
-
-
-この案内は🧨 Diffusers [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)を簡略化したもので、すぐに使い始めることができます。Diffusers 🧨のゴール、設計哲学、コアAPIの詳細についてもっと知りたい方は、ノートブックをご覧ください!
-
-
+> [!TIP]
+> この案内は🧨 Diffusers [ノートブック](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)を簡略化したもので、すぐに使い始めることができます。Diffusers 🧨のゴール、設計哲学、コアAPIの詳細についてもっと知りたい方は、ノートブックをご覧ください!
始める前に必要なライブラリーがすべてインストールされていることを確認してください:
@@ -56,11 +53,8 @@ specific language governing permissions and limitations under the License.
この[`DiffusionPipeline`]はHugging Face Hubに保存されている任意の[チェックポイント](https://huggingface.co/models?library=diffusers&sort=downloads)を使用することができます。
この案内では、[`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)チェックポイントでテキストから画像へ生成します。
-
-
-[Stable Diffusion]モデルについては、モデルを実行する前にまず[ライセンス](https://huggingface.co/spaces/CompVis/stable-diffusion-license)を注意深くお読みください。🧨 Diffusers は、攻撃的または有害なコンテンツを防ぐために [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) を実装していますが、モデルの改良された画像生成機能により、潜在的に有害なコンテンツが生成される可能性があります。
-
-
+> [!WARNING]
+> [Stable Diffusion]モデルについては、モデルを実行する前にまず[ライセンス](https://huggingface.co/spaces/CompVis/stable-diffusion-license)を注意深くお読みください。🧨 Diffusers は、攻撃的または有害なコンテンツを防ぐために [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) を実装していますが、モデルの改良された画像生成機能により、潜在的に有害なコンテンツが生成される可能性があります。
モデルを[`~DiffusionPipeline.from_pretrained`]メソッドでロードします:
@@ -204,11 +198,8 @@ torch.Size([1, 3, 256, 256])
スケジューラは、モデルの出力(この場合は `noisy_residual` )が与えられたときに、ノイズの多いサンプルからノイズの少ないサンプルへの移行を管理します。
-
-
-🧨 Diffusersは拡散システムを構築するためのツールボックスです。[`DiffusionPipeline`]は事前に構築された拡散システムを使い始めるのに便利な方法ですが、独自のモデルとスケジューラコンポーネントを個別に選択してカスタム拡散システムを構築することもできます。
-
-
+> [!TIP]
+> 🧨 Diffusersは拡散システムを構築するためのツールボックスです。[`DiffusionPipeline`]は事前に構築された拡散システムを使い始めるのに便利な方法ですが、独自のモデルとスケジューラコンポーネントを個別に選択してカスタム拡散システムを構築することもできます。
この案内では、[`DDPMScheduler`]を[`~diffusers.ConfigMixin.from_config`]メソッドでインスタンス化します:
@@ -232,11 +223,8 @@ DDPMScheduler {
}
```
-
-
-💡 スケジューラがどのようにコンフィギュレーションからインスタンス化されるかに注目してください。モデルとは異なり、スケジューラは学習可能な重みを持たず、パラメーターを持ちません!
-
-
+> [!TIP]
+> 💡 スケジューラがどのようにコンフィギュレーションからインスタンス化されるかに注目してください。モデルとは異なり、スケジューラは学習可能な重みを持たず、パラメーターを持ちません!
最も重要なパラメータは以下の通りです:
diff --git a/docs/source/ja/stable_diffusion.md b/docs/source/ja/stable_diffusion.md
index 85f2b38a7d..79abfa005d 100644
--- a/docs/source/ja/stable_diffusion.md
+++ b/docs/source/ja/stable_diffusion.md
@@ -37,11 +37,8 @@ prompt = "portrait photo of a old warrior chief"
## Speed
-
-
-💡 GPUを利用できない場合は、[Colab](https://colab.research.google.com/)のようなGPUプロバイダーから無料で利用できます!
-
-
+> [!TIP]
+> 💡 GPUを利用できない場合は、[Colab](https://colab.research.google.com/)のようなGPUプロバイダーから無料で利用できます!
画像生成を高速化する最も簡単な方法の1つは、PyTorchモジュールと同じようにGPU上にパイプラインを配置することです:
@@ -88,11 +85,8 @@ image
今回、画像生成にかかった時間はわずか11秒で、以前より3倍近く速くなりました!
-
-
-💡 パイプラインは常に `float16` で実行することを強くお勧めします。
-
-
+> [!TIP]
+> 💡 パイプラインは常に `float16` で実行することを強くお勧めします。
生成ステップ数を減らすという方法もあります。より効率的なスケジューラを選択することで、出力品質を犠牲にすることなくステップ数を減らすことができます。`compatibles`メソッドを呼び出すことで、[`DiffusionPipeline`]の現在のモデルと互換性のあるスケジューラを見つけることができます:
diff --git a/docs/source/ja/tutorials/autopipeline.md b/docs/source/ja/tutorials/autopipeline.md
index a9a780186a..7dc678da90 100644
--- a/docs/source/ja/tutorials/autopipeline.md
+++ b/docs/source/ja/tutorials/autopipeline.md
@@ -16,11 +16,8 @@ Diffusersは様々なタスクをこなすことができ、テキストから
`AutoPipeline` クラスは、🤗 Diffusers の様々なパイプラインをよりシンプルするために設計されています。この汎用的でタスク重視のパイプラインによってタスクそのものに集中することができます。`AutoPipeline` は、使用するべき正しいパイプラインクラスを自動的に検出するため、特定のパイプラインクラス名を知らなくても、タスクのチェックポイントを簡単にロードできます。
-
-
-どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
-
-
+> [!TIP]
+> どのタスクがサポートされているかは、[AutoPipeline](../api/pipelines/auto_pipeline) のリファレンスをご覧ください。現在、text-to-image、image-to-image、inpaintingをサポートしています。
このチュートリアルでは、`AutoPipeline` を使用して、事前に学習された重みが与えられたときに、特定のタスクを読み込むためのパイプラインクラスを自動的に推測する方法を示します。
diff --git a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
index 34a00d63fe..ba85b4a855 100644
--- a/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
+++ b/docs/source/ko/api/pipelines/stable_diffusion/stable_diffusion_xl.md
@@ -207,11 +207,8 @@ image = refiner(
동일한 40 단계에서 base 모델을 실행한다면, 이미지의 디테일(예: 사자의 눈과 코)이 떨어졌을 것입니다:
-
-
-앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!
-
-
+> [!TIP]
+> 앙상블 방식은 사용 가능한 모든 스케줄러에서 잘 작동합니다!
#### 2.) 노이즈가 완전히 제거된 기본 이미지에서 이미지 출력을 정제하기
@@ -248,11 +245,8 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
|---|---|
|  |  |
-
-
-refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.
-
-
+> [!TIP]
+> refiner는 또한 인페인팅 설정에 잘 사용될 수 있습니다. 아래에 보여지듯이 [`StableDiffusionXLInpaintPipeline`] 클래스를 사용해서 만들어보세요.
Denoiser 앙상블 설정에서 인페인팅에 refiner를 사용하려면 다음을 수행하면 됩니다:
diff --git a/docs/source/ko/conceptual/ethical_guidelines.md b/docs/source/ko/conceptual/ethical_guidelines.md
index b8c55048bf..63fc4a7741 100644
--- a/docs/source/ko/conceptual/ethical_guidelines.md
+++ b/docs/source/ko/conceptual/ethical_guidelines.md
@@ -14,51 +14,47 @@ specific language governing permissions and limitations under the License.
## 서문 [[preamble]]
-[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며 추론 및 훈련을 위한 모듈식 툴박스로 사용됩니다.
+[Diffusers](https://huggingface.co/docs/diffusers/index)는 사전 훈련된 diffusion 모델을 제공하며, 추론과 훈련을 위한 모듈형 툴박스로 활용됩니다.
-이 기술의 실제 적용과 사회에 미칠 수 있는 부정적인 영향을 고려하여 Diffusers 라이브러리의 개발, 사용자 기여 및 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
-
-이이 기술을 사용함에 따른 위험은 여전히 검토 중이지만, 몇 가지 예를 들면: 예술가들에 대한 저작권 문제; 딥 페이크의 악용; 부적절한 맥락에서의 성적 콘텐츠 생성; 동의 없는 사칭; 소수자 집단의 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
-
-우리는 위험을 지속적으로 추적하고 커뮤니티의 응답과 소중한 피드백에 따라 다음 지침을 조정할 것입니다.
+이 기술의 실제 적용 사례와 사회에 미칠 수 있는 잠재적 부정적 영향을 고려할 때, Diffusers 라이브러리의 개발, 사용자 기여, 사용에 윤리 지침을 제공하는 것이 중요하다고 생각합니다.
+이 기술 사용과 관련된 위험은 여전히 검토 중이지만, 예를 들면: 예술가의 저작권 문제, 딥페이크 악용, 부적절한 맥락에서의 성적 콘텐츠 생성, 비동의 사칭, 소수자 집단 억압을 영속화하는 유해한 사회적 편견 등이 있습니다.
+우리는 이러한 위험을 지속적으로 추적하고, 커뮤니티의 반응과 소중한 피드백에 따라 아래 지침을 조정할 것입니다.
## 범위 [[scope]]
-Diffusers 커뮤니티는 프로젝트의 개발에 다음과 같은 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대한 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
-
+Diffusers 커뮤니티는 프로젝트 개발에 다음 윤리 지침을 적용하며, 특히 윤리적 문제와 관련된 민감한 주제에 대해 커뮤니티의 기여를 조정하는 데 도움을 줄 것입니다.
## 윤리 지침 [[ethical-guidelines]]
-다음 윤리 지침은 일반적으로 적용되지만, 민감한 윤리적 문제와 관련하여 기술적 선택을 할 때 이를 우선적으로 적용할 것입니다. 나아가, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 조정할 것을 약속드립니다.
+다음 윤리 지침은 일반적으로 적용되지만, 윤리적으로 민감한 문제와 관련된 기술적 선택을 할 때 우선적으로 적용됩니다. 또한, 해당 기술의 최신 동향과 관련된 새로운 위험이 발생함에 따라 이러한 윤리 원칙을 지속적으로 조정할 것을 약속합니다.
-- **투명성**: 우리는 PR을 관리하고, 사용자에게 우리의 선택을 설명하며, 기술적 의사결정을 내릴 때 투명성을 유지할 것을 약속합니다.
+- **투명성**: 우리는 PR 관리, 사용자에게 선택의 이유 설명, 기술적 의사결정 과정에서 투명성을 유지할 것을 약속합니다.
-- **일관성**: 우리는 프로젝트 관리에서 사용자들에게 동일한 수준의 관심을 보장하고 기술적으로 안정되고 일관된 상태를 유지할 것을 약속합니다.
+- **일관성**: 프로젝트 관리에서 모든 사용자에게 동일한 수준의 관심을 보장하고, 기술적으로 안정적이고 일관된 상태를 유지할 것을 약속합니다.
-- **간결성**: Diffusers 라이브러리를 사용하고 활용하기 쉽게 만들기 위해, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
+- **간결성**: Diffusers 라이브러리를 쉽게 사용하고 활용할 수 있도록, 프로젝트의 목표를 간결하고 일관성 있게 유지할 것을 약속합니다.
-- **접근성**: Diffusers 프로젝트는 기술적 전문 지식 없어도 프로젝트 운영에 참여할 수 있는 기여자의 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근할 수 있게 됩니다.
+- **접근성**: Diffusers 프로젝트는 기술적 전문지식이 없어도 기여할 수 있도록 진입장벽을 낮춥니다. 이를 통해 연구 결과물이 커뮤니티에 더 잘 접근될 수 있습니다.
-- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림(upstream) 코드, 모델 및 데이터셋의 재현성에 대해 투명하게 공개할 것을 목표로 합니다.
-
-- **책임**: 우리는 커뮤니티와 팀워크를 통해, 이 기술의 잠재적인 위험과 위험을 예측하고 완화하는 데 대한 공동 책임을 가지고 있습니다.
+- **재현성**: 우리는 Diffusers 라이브러리를 통해 제공되는 업스트림 코드, 모델, 데이터셋의 재현성에 대해 투명하게 공개하는 것을 목표로 합니다.
+- **책임**: 커뮤니티와 팀워크를 통해, 이 기술의 잠재적 위험을 예측하고 완화하는 데 공동 책임을 집니다.
## 구현 사례: 안전 기능과 메커니즘 [[examples-of-implementations-safety-features-and-mechanisms]]
-팀은 diffusion 기술과 관련된 잠재적인 윤리 및 사회적 위험에 대처하기 위한 기술적 및 비기술적 도구를 제공하고자 하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능의 구현하고 우리와 함께 인식을 높이는 데 매우 중요합니다.
+팀은 diffusion 기술과 관련된 잠재적 윤리 및 사회적 위험에 대응하기 위해 기술적·비기술적 도구를 제공하고자 노력하고 있습니다. 또한, 커뮤니티의 참여는 이러한 기능 구현과 인식 제고에 매우 중요합니다.
-- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 이를 통해 커뮤니티는 프로젝트에 대해 토론하고 더 나은 협력을 할 수 있습니다.
+- [**커뮤니티 탭**](https://huggingface.co/docs/hub/repositories-pull-requests-discussions): 커뮤니티가 프로젝트에 대해 토론하고 더 나은 협업을 할 수 있도록 지원합니다.
-- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)을 제공합니다. 이런 의미에서, 우리는 편향 탐색 및 평가를 지원하고 장려합니다.
+- **편향 탐색 및 평가**: Hugging Face 팀은 Stable Diffusion 모델의 편향성을 대화형으로 보여주는 [space](https://huggingface.co/spaces/society-ethics/DiffusionBiasExplorer)를 제공합니다. 우리는 이러한 편향 탐색과 평가를 지원하고 장려합니다.
- **배포에서의 안전 유도**
- - [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 이는 필터되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절한 변질에 취약한 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
+ - [**안전한 Stable Diffusion**](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_safe): 필터링되지 않은 웹 크롤링 데이터셋으로 훈련된 Stable Diffusion과 같은 모델이 부적절하게 변질되는 문제를 완화합니다. 관련 논문: [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105).
- - [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 이미지가 생성된 후에 이미자가 임베딩 공간에서 일련의 하드코딩된 유해 개념의 클래스일 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
+ - [**안전 검사기**](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py): 생성된 이미지가 임베딩 공간에서 하드코딩된 유해 개념 클래스와 일치할 확률을 확인하고 비교합니다. 유해 개념은 역공학을 방지하기 위해 의도적으로 숨겨져 있습니다.
-- **Hub에서의 단계적인 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한해야 합니다. 이 단계적인 배포는 중간 단계로, 리포지토리 작성자가 사용에 대한 더 많은 통제력을 갖게 합니다.
+- **Hub에서의 단계적 배포**: 특히 민감한 상황에서는 일부 리포지토리에 대한 접근을 제한할 수 있습니다. 단계적 배포는 리포지토리 작성자가 사용에 대해 더 많은 통제권을 갖도록 하는 중간 단계입니다.
-- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선싱을 통해 자유로운 접근을 보장하면서도 더 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
+- **라이선싱**: [OpenRAILs](https://huggingface.co/blog/open_rail)와 같은 새로운 유형의 라이선스를 통해 자유로운 접근을 보장하면서도 보다 책임 있는 사용을 위한 일련의 제한을 둘 수 있습니다.
diff --git a/docs/source/ko/conceptual/evaluation.md b/docs/source/ko/conceptual/evaluation.md
index 2d296420bc..731b511485 100644
--- a/docs/source/ko/conceptual/evaluation.md
+++ b/docs/source/ko/conceptual/evaluation.md
@@ -95,11 +95,8 @@ images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generato
다양한 모델을 사용하여 모든 프롬프트에서 생성된 여러 이미지들이 생성되면 (평가 과정에서) 이러한 결과물들은 사람 평가자들에게 점수를 매기기 위해 제시됩니다. DrawBench와 PartiPrompts 벤치마크에 대한 자세한 내용은 각각의 논문을 참조하십시오.
-
-
-모델이 훈련 중일 때 추론 샘플을 살펴보는 것은 훈련 진행 상황을 측정하는 데 유용합니다. [훈련 스크립트](https://github.com/huggingface/diffusers/tree/main/examples/)에서는 TensorBoard와 Weights & Biases에 대한 추가 지원과 함께 이 유틸리티를 지원합니다.
-
-
+> [!TIP]
+> 모델이 훈련 중일 때 추론 샘플을 살펴보는 것은 훈련 진행 상황을 측정하는 데 유용합니다. [훈련 스크립트](https://github.com/huggingface/diffusers/tree/main/examples/)에서는 TensorBoard와 Weights & Biases에 대한 추가 지원과 함께 이 유틸리티를 지원합니다.
## 정량적 평가[[quantitative-evaluation]]
@@ -193,11 +190,8 @@ print(f"CLIP Score with v-1-5: {sd_clip_score_1_5}")
[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트가 이전 버전보다 더 나은 성능을 보이는 것 같습니다. 그러나 CLIP 점수를 계산하기 위해 사용한 프롬프트의 수가 상당히 적습니다. 보다 실용적인 평가를 위해서는 이 수를 훨씬 높게 설정하고, 프롬프트를 다양하게 사용해야 합니다.
-
-
-이 점수에는 몇 가지 제한 사항이 있습니다. 훈련 데이터셋의 캡션은 웹에서 크롤링되어 이미지와 관련된 `alt` 및 유사한 태그에서 추출되었습니다. 이들은 인간이 이미지를 설명하는 데 사용할 수 있는 것과 일치하지 않을 수 있습니다. 따라서 여기서는 몇 가지 프롬프트를 "엔지니어링"해야 했습니다.
-
-
+> [!WARNING]
+> 이 점수에는 몇 가지 제한 사항이 있습니다. 훈련 데이터셋의 캡션은 웹에서 크롤링되어 이미지와 관련된 `alt` 및 유사한 태그에서 추출되었습니다. 이들은 인간이 이미지를 설명하는 데 사용할 수 있는 것과 일치하지 않을 수 있습니다. 따라서 여기서는 몇 가지 프롬프트를 "엔지니어링"해야 했습니다.
### 이미지 조건화된 텍스트-이미지 생성[[image-conditioned-text-to-image-generation]]
@@ -405,11 +399,8 @@ CLIP 점수와 마찬가지로, CLIP 방향 유사성이 높을수록 좋습니
[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)와 같은 유사한 파이프라인에도 이러한 메트릭을 사용할 수 있습니다.
-
-
-CLIP 점수와 CLIP 방향 유사성 모두 CLIP 모델에 의존하기 때문에 평가가 편향될 수 있습니다
-
-
+> [!TIP]
+> CLIP 점수와 CLIP 방향 유사성 모두 CLIP 모델에 의존하기 때문에 평가가 편향될 수 있습니다
***IS, FID (나중에 설명할 예정), 또는 KID와 같은 메트릭을 확장하는 것은 어려울 수 있습니다***. 평가 중인 모델이 대규모 이미지 캡셔닝 데이터셋 (예: [LAION-5B 데이터셋](https://laion.ai/blog/laion-5b/))에서 사전 훈련되었을 때 이는 문제가 될 수 있습니다. 왜냐하면 이러한 메트릭의 기반에는 중간 이미지 특징을 추출하기 위해 ImageNet-1k 데이터셋에서 사전 훈련된 InceptionNet이 사용되기 때문입니다. Stable Diffusion의 사전 훈련 데이터셋은 InceptionNet의 사전 훈련 데이터셋과 겹치는 부분이 제한적일 수 있으므로 따라서 여기에는 좋은 후보가 아닙니다.
@@ -532,19 +523,16 @@ FID는 낮을수록 좋습니다. 여러 가지 요소가 FID에 영향을 줄
마지막 두 가지 요소에 대해서는, 다른 시드와 추론 단계에서 평가를 실행하고 평균 결과를 보고하는 것은 좋은 실천 방법입니다
-
-
-FID 결과는 많은 요소에 의존하기 때문에 취약할 수 있습니다:
-
-* 계산 중 사용되는 특정 Inception 모델.
-* 계산의 구현 정확도.
-* 이미지 형식 (PNG 또는 JPG에서 시작하는 경우가 다릅니다).
-
-이러한 사항을 염두에 두면, FID는 유사한 실행을 비교할 때 가장 유용하지만, 저자가 FID 측정 코드를 주의 깊게 공개하지 않는 한 논문 결과를 재현하기는 어렵습니다.
-
-이러한 사항은 KID 및 IS와 같은 다른 관련 메트릭에도 적용됩니다.
-
-
+> [!WARNING]
+> FID 결과는 많은 요소에 의존하기 때문에 취약할 수 있습니다:
+>
+> * 계산 중 사용되는 특정 Inception 모델.
+> * 계산의 구현 정확도.
+> * 이미지 형식 (PNG 또는 JPG에서 시작하는 경우가 다릅니다).
+>
+> 이러한 사항을 염두에 두면, FID는 유사한 실행을 비교할 때 가장 유용하지만, 저자가 FID 측정 코드를 주의 깊게 공개하지 않는 한 논문 결과를 재현하기는 어렵습니다.
+>
+> 이러한 사항은 KID 및 IS와 같은 다른 관련 메트릭에도 적용됩니다.
마지막 단계로, `fake_images`를 시각적으로 검사해 봅시다.
diff --git a/docs/source/ko/installation.md b/docs/source/ko/installation.md
index c03b464290..198ca4b7c7 100644
--- a/docs/source/ko/installation.md
+++ b/docs/source/ko/installation.md
@@ -107,11 +107,8 @@ pip install -e ".[flax]"
Python은 이제 일반 라이브러리 경로에 더하여 복제한 폴더 내부를 살펴봅니다.
예를들어 Python 패키지가 `~/anaconda3/envs/main/lib/python3.10/site-packages/`에 설치되어 있는 경우 Python은 복제한 폴더인 `~/diffusers/`도 검색합니다.
-
-
-라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.
-
-
+> [!WARNING]
+> 라이브러리를 계속 사용하려면 `diffusers` 폴더를 유지해야 합니다.
이제 다음 명령어를 사용하여 최신 버전의 🤗 Diffusers로 쉽게 업데이트할 수 있습니다:
diff --git a/docs/source/ko/optimization/coreml.md b/docs/source/ko/optimization/coreml.md
index 60f19fd2c3..73ca851177 100644
--- a/docs/source/ko/optimization/coreml.md
+++ b/docs/source/ko/optimization/coreml.md
@@ -16,11 +16,8 @@ specific language governing permissions and limitations under the License.
Core ML 모델은 Apple 기기에서 사용할 수 있는 모든 컴퓨팅 엔진들, 즉 CPU, GPU, Apple Neural Engine(또는 Apple Silicon Mac 및 최신 iPhone/iPad에서 사용할 수 있는 텐서 최적화 가속기인 ANE)을 활용할 수 있습니다. 모델과 실행 중인 기기에 따라 Core ML은 컴퓨팅 엔진도 혼합하여 사용할 수 있으므로, 예를 들어 모델의 일부가 CPU에서 실행되는 반면 다른 부분은 GPU에서 실행될 수 있습니다.
-
-
-PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Macs에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드]에 자세히 설명되어 있지만 네이티브 앱과 호환되지 않습니다.
-
-
+> [!TIP]
+> PyTorch에 내장된 `mps` 가속기를 사용하여 Apple Silicon Macs에서 `diffusers` Python 코드베이스를 실행할 수도 있습니다. 이 방법은 [mps 가이드]에 자세히 설명되어 있지만 네이티브 앱과 호환되지 않습니다.
## Stable Diffusion Core ML 체크포인트
diff --git a/docs/source/ko/optimization/fp16.md b/docs/source/ko/optimization/fp16.md
index db0370875e..56f1330c40 100644
--- a/docs/source/ko/optimization/fp16.md
+++ b/docs/source/ko/optimization/fp16.md
@@ -74,18 +74,16 @@ prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
```
-
- 어떤 파이프라인에서도 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) 를 사용하는 것은 검은색 이미지를 생성할 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.
-
+> [!WARNING]
+> 어떤 파이프라인에서도 [`torch.autocast`](https://pytorch.org/docs/stable/amp.html#torch.autocast) 를 사용하는 것은 검은색 이미지를 생성할 수 있고, 순수한 float16 정밀도를 사용하는 것보다 항상 느리기 때문에 사용하지 않는 것이 좋습니다.
## 추가 메모리 절약을 위한 슬라이스 어텐션
추가 메모리 절약을 위해, 한 번에 모두 계산하는 대신 단계적으로 계산을 수행하는 슬라이스 버전의 어텐션(attention)을 사용할 수 있습니다.
-
- Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.
- 하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 매트릭스는 상당한 양의 메모리를 절약할 수 있는 각 헤드에 대해 순차적으로 계산될 수 있습니다.
-
+> [!TIP]
+> Attention slicing은 모델이 하나 이상의 어텐션 헤드를 사용하는 한, 배치 크기가 1인 경우에도 유용합니다.
+> 하나 이상의 어텐션 헤드가 있는 경우 *QK^T* 어텐션 매트릭스는 상당한 양의 메모리를 절약할 수 있는 각 헤드에 대해 순차적으로 계산될 수 있습니다.
각 헤드에 대해 순차적으로 어텐션 계산을 수행하려면, 다음과 같이 추론 전에 파이프라인에서 [`~StableDiffusionPipeline.enable_attention_slicing`]를 호출하면 됩니다:
@@ -161,9 +159,8 @@ image = pipe(prompt).images[0]
참고로 이 방법은 전체 모델이 아닌 서브모듈 수준에서 작동합니다. 이는 메모리 소비를 최소화하는 가장 좋은 방법이지만 프로세스의 반복적 특성으로 인해 추론 속도가 훨씬 느립니다. 파이프라인의 UNet 구성 요소는 여러 번 실행됩니다('num_inference_steps' 만큼). 매번 UNet의 서로 다른 서브모듈이 순차적으로 온로드된 다음 필요에 따라 오프로드되므로 메모리 이동 횟수가 많습니다.
-
-또 다른 최적화 방법인 모델 오프로딩을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.
-
+> [!TIP]
+> 또 다른 최적화 방법인 모델 오프로딩을 사용하는 것을 고려하십시오. 이는 훨씬 빠르지만 메모리 절약이 크지는 않습니다.
또한 ttention slicing과 연결해서 최소 메모리(< 2GB)로도 동작할 수 있습니다.
@@ -231,9 +228,8 @@ pipe.enable_attention_slicing(1)
image = pipe(prompt).images[0]
```
-
-이 기능을 사용하려면 'accelerate' 버전 0.17.0 이상이 필요합니다.
-
+> [!TIP]
+> 이 기능을 사용하려면 'accelerate' 버전 0.17.0 이상이 필요합니다.
## Channels Last 메모리 형식 사용하기
diff --git a/docs/source/ko/optimization/mps.md b/docs/source/ko/optimization/mps.md
index 4daeaf5dba..004374c4af 100644
--- a/docs/source/ko/optimization/mps.md
+++ b/docs/source/ko/optimization/mps.md
@@ -27,11 +27,8 @@ Diffusers는 Stable Diffusion 추론을 위해 PyTorch `mps`를 사용해 Apple
아래 코도는 익숙한 `to()` 인터페이스를 사용하여 `mps` 백엔드로 Stable Diffusion 파이프라인을 M1 또는 M2 장치로 이동하는 방법을 보여줍니다.
-
-
-**PyTorch 1.13을 사용 중일 때 ** 추가 일회성 전달을 사용하여 파이프라인을 "프라이밍"하는 것을 추천합니다. 이것은 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달와 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.
-
-
+> [!WARNING]
+> **PyTorch 1.13을 사용 중일 때 ** 추가 일회성 전달을 사용하여 파이프라인을 "프라이밍"하는 것을 추천합니다. 이것은 발견한 이상한 문제에 대한 임시 해결 방법입니다. 첫 번째 추론 전달은 후속 전달와 약간 다른 결과를 생성합니다. 이 전달은 한 번만 수행하면 되며 추론 단계를 한 번만 사용하고 결과를 폐기해도 됩니다.
이전 팁에서 설명한 것들을 포함한 여러 문제를 해결하므로 PyTorch 2 이상을 사용하는 것이 좋습니다.
diff --git a/docs/source/ko/optimization/torch2.0.md b/docs/source/ko/optimization/torch2.0.md
index c78c4a87b6..354f7243cf 100644
--- a/docs/source/ko/optimization/torch2.0.md
+++ b/docs/source/ko/optimization/torch2.0.md
@@ -173,7 +173,7 @@ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data
init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
-path = "runwayml/stable-diffusion-inpainting"
+path = "stable-diffusion-v1-5/stable-diffusion-inpainting"
run_compile = True # Set True / False
diff --git a/docs/source/ko/optimization/xformers.md b/docs/source/ko/optimization/xformers.md
index 3e4d107c0a..96fab34acf 100644
--- a/docs/source/ko/optimization/xformers.md
+++ b/docs/source/ko/optimization/xformers.md
@@ -21,16 +21,10 @@ specific language governing permissions and limitations under the License.
pip install xformers
```
-
-
-xFormers PIP 패키지에는 최신 버전의 PyTorch(xFormers 0.0.16에 1.13.1)가 필요합니다. 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)의 소스를 사용해 xFormers를 설치하는 것이 좋습니다.
-
-
+> [!TIP]
+> xFormers PIP 패키지에는 최신 버전의 PyTorch(xFormers 0.0.16에 1.13.1)가 필요합니다. 이전 버전의 PyTorch를 사용해야 하는 경우 [프로젝트 지침](https://github.com/facebookresearch/xformers#installing-xformers)의 소스를 사용해 xFormers를 설치하는 것이 좋습니다.
xFormers를 설치하면, [여기](fp16#memory-efficient-attention)서 설명한 것처럼 'enable_xformers_memory_efficient_attention()'을 사용하여 추론 속도를 높이고 메모리 소비를 줄일 수 있습니다.
-
-
-[이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서 GPU를 사용한 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발견되면. 해당 코멘트를 참고해 development 버전을 설치하세요.
-
-
+> [!WARNING]
+> [이 이슈](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)에 따르면 xFormers `v0.0.16`에서 GPU를 사용한 학습(파인 튜닝 또는 Dreambooth)을 할 수 없습니다. 해당 문제가 발견되면. 해당 코멘트를 참고해 development 버전을 설치하세요.
diff --git a/docs/source/ko/quicktour.md b/docs/source/ko/quicktour.md
index 58ebb8960f..0a3cd0f7c4 100644
--- a/docs/source/ko/quicktour.md
+++ b/docs/source/ko/quicktour.md
@@ -23,11 +23,8 @@ Diffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성
훑어보기에서는 추론을 위해 [`DiffusionPipeline`]을 사용하는 방법을 보여준 다음, 모델과 스케줄러를 결합하여 [`DiffusionPipeline`] 내부에서 일어나는 일을 복제하는 방법을 안내합니다.
-
-
-훑어보기는 간결한 버전의 🧨 Diffusers 소개로서 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) 빠르게 시작할 수 있도록 도와드립니다. 디퓨저의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!
-
-
+> [!TIP]
+> 훑어보기는 간결한 버전의 🧨 Diffusers 소개로서 [노트북](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) 빠르게 시작할 수 있도록 도와드립니다. 디퓨저의 목표, 디자인 철학, 핵심 API에 대한 추가 세부 정보를 자세히 알아보려면 노트북을 확인하세요!
시작하기 전에 필요한 라이브러리가 모두 설치되어 있는지 확인하세요:
@@ -55,11 +52,8 @@ Diffusion 모델은 이미지나 오디오와 같은 관심 샘플들을 생성
허깅페이스 허브에 저장된 모든 [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads)에 대해 [`DiffusionPipeline`]을 사용할 수 있습니다.
이 훑어보기에서는 text-to-image 생성을 위한 [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 체크포인트를 로드합니다.
-
-
-[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.
-
-
+> [!WARNING]
+> [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion) 모델의 경우, 모델을 실행하기 전에 [라이선스](https://huggingface.co/spaces/CompVis/stable-diffusion-license)를 먼저 주의 깊게 읽어주세요. 🧨 Diffusers는 불쾌하거나 유해한 콘텐츠를 방지하기 위해 [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py)를 구현하고 있지만, 모델의 향상된 이미지 생성 기능으로 인해 여전히 잠재적으로 유해한 콘텐츠가 생성될 수 있습니다.
[`~DiffusionPipeline.from_pretrained`] 방법으로 모델 로드하기:
@@ -203,11 +197,8 @@ torch.Size([1, 3, 256, 256])
스케줄러는 모델 출력이 주어졌을 때 노이즈가 많은 샘플에서 노이즈가 적은 샘플로 전환하는 것을 관리합니다 - 이 경우 'noisy_residual'.
-
-
-🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. [`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.
-
-
+> [!TIP]
+> 🧨 Diffusers는 Diffusion 시스템을 구축하기 위한 툴박스입니다. [`DiffusionPipeline`]을 사용하면 미리 만들어진 Diffusion 시스템을 편리하게 시작할 수 있지만, 모델과 스케줄러 구성 요소를 개별적으로 선택하여 사용자 지정 Diffusion 시스템을 구축할 수도 있습니다.
훑어보기의 경우, [`~diffusers.ConfigMixin.from_config`] 메서드를 사용하여 [`DDPMScheduler`]를 인스턴스화합니다:
@@ -231,11 +222,8 @@ DDPMScheduler {
}
```
-
-
-💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!
-
-
+> [!TIP]
+> 💡 스케줄러가 구성에서 어떻게 인스턴스화되는지 주목하세요. 모델과 달리 스케줄러에는 학습 가능한 가중치가 없으며 매개변수도 없습니다!
가장 중요한 매개변수는 다음과 같습니다:
diff --git a/docs/source/ko/stable_diffusion.md b/docs/source/ko/stable_diffusion.md
index 794bdf9c66..0f61e16d2a 100644
--- a/docs/source/ko/stable_diffusion.md
+++ b/docs/source/ko/stable_diffusion.md
@@ -37,11 +37,8 @@ prompt = "portrait photo of a old warrior chief"
## 속도
-
-
-💡 GPU에 액세스할 수 없는 경우 다음과 같은 GPU 제공업체에서 무료로 사용할 수 있습니다!. [Colab](https://colab.research.google.com/)
-
-
+> [!TIP]
+> 💡 GPU에 액세스할 수 없는 경우 다음과 같은 GPU 제공업체에서 무료로 사용할 수 있습니다!. [Colab](https://colab.research.google.com/)
추론 속도를 높이는 가장 간단한 방법 중 하나는 Pytorch 모듈을 사용할 때와 같은 방식으로 GPU에 파이프라인을 배치하는 것입니다:
@@ -89,11 +86,8 @@ image
이번에는 이미지를 생성하는 데 약 11초밖에 걸리지 않아 이전보다 3배 가까이 빨라졌습니다!
-
-
-💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.
-
-
+> [!TIP]
+> 💡 파이프라인은 항상 `float16`에서 실행할 것을 강력히 권장하며, 지금까지 출력 품질이 저하되는 경우는 거의 없었습니다.
또 다른 옵션은 추론 단계의 수를 줄이는 것입니다. 보다 효율적인 스케줄러를 선택하면 출력 품질 저하 없이 단계 수를 줄이는 데 도움이 될 수 있습니다. 현재 모델과 호환되는 스케줄러는 `compatibles` 메서드를 호출하여 [`DiffusionPipeline`]에서 찾을 수 있습니다:
diff --git a/docs/source/ko/training/adapt_a_model.md b/docs/source/ko/training/adapt_a_model.md
index 3795558f5f..fe6fde05b7 100644
--- a/docs/source/ko/training/adapt_a_model.md
+++ b/docs/source/ko/training/adapt_a_model.md
@@ -28,12 +28,12 @@ pipeline.unet.config["in_channels"]
4
```
-인페인팅은 입력 샘플에 9개의 채널이 필요합니다. [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:
+인페인팅은 입력 샘플에 9개의 채널이 필요합니다. [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)와 같은 사전학습된 인페인팅 모델에서 이 값을 확인할 수 있습니다:
```py
from diffusers import StableDiffusionPipeline
-pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting")
pipeline.unet.config["in_channels"]
9
```
diff --git a/docs/source/ko/training/controlnet.md b/docs/source/ko/training/controlnet.md
index 434ca959bd..e868b57c55 100644
--- a/docs/source/ko/training/controlnet.md
+++ b/docs/source/ko/training/controlnet.md
@@ -20,11 +20,8 @@ specific language governing permissions and limitations under the License.
아래의 스크립트를 실행하기 전에, 라이브러리의 학습 의존성을 설치해야 합니다.
-
-
-가장 최신 버전의 예시 스크립트를 성공적으로 실행하기 위해서는, 소스에서 설치하고 최신 버전의 설치를 유지하는 것을 강력하게 추천합니다. 우리는 예시 스크립트들을 자주 업데이트하고 예시에 맞춘 특정한 요구사항을 설치합니다.
-
-
+> [!WARNING]
+> 가장 최신 버전의 예시 스크립트를 성공적으로 실행하기 위해서는, 소스에서 설치하고 최신 버전의 설치를 유지하는 것을 강력하게 추천합니다. 우리는 예시 스크립트들을 자주 업데이트하고 예시에 맞춘 특정한 요구사항을 설치합니다.
위 사항을 만족시키기 위해서, 새로운 가상환경에서 다음 일련의 스텝을 실행하세요:
diff --git a/docs/source/ko/training/create_dataset.md b/docs/source/ko/training/create_dataset.md
index a869cd09f0..c459a9d6a1 100644
--- a/docs/source/ko/training/create_dataset.md
+++ b/docs/source/ko/training/create_dataset.md
@@ -11,11 +11,8 @@
- 이미지 폴더를 `--train_data_dir` 인수에 제공합니다.
- 데이터셋을 Hub에 업로드하고 데이터셋 리포지토리 id를 `--dataset_name` 인수에 전달합니다.
-
-
-💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.
-
-
+> [!TIP]
+> 💡 학습에 사용할 이미지 데이터셋을 만드는 방법에 대한 자세한 내용은 [이미지 데이터셋 만들기](https://huggingface.co/docs/datasets/image_dataset) 가이드를 참고하세요.
## 폴더 형태로 데이터셋 구축하기
@@ -40,11 +37,8 @@ accelerate launch train_unconditional.py \
## Hub에 데이터 올리기
-
-
-💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.
-
-
+> [!TIP]
+> 💡 데이터셋을 만들고 Hub에 업로드하는 것에 대한 자세한 내용은 [🤗 Datasets을 사용한 이미지 검색](https://huggingface.co/blog/image-search-datasets) 게시물을 참고하세요.
PIL 인코딩된 이미지가 포함된 `이미지` 열을 생성하는 [이미지 폴더](https://huggingface.co/docs/datasets/image_load#imagefolder) 기능을 사용하여 데이터셋 생성을 시작합니다.
diff --git a/docs/source/ko/training/distributed_inference.md b/docs/source/ko/training/distributed_inference.md
index c4d6400d97..e63764f5eb 100644
--- a/docs/source/ko/training/distributed_inference.md
+++ b/docs/source/ko/training/distributed_inference.md
@@ -32,9 +32,8 @@ Use the `--num_processes` argument to specify the number of GPUs to use, and cal
accelerate launch run_distributed.py --num_processes=2
```
-자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.
-
-
+> [!TIP]
+> 자세한 내용은 [🤗 Accelerate를 사용한 분산 추론](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 가이드를 참조하세요.
## Pytoerch 분산
diff --git a/docs/source/ko/training/dreambooth.md b/docs/source/ko/training/dreambooth.md
index 8e62f8edab..3e5a17d5f6 100644
--- a/docs/source/ko/training/dreambooth.md
+++ b/docs/source/ko/training/dreambooth.md
@@ -51,11 +51,8 @@ write_basic_config()
## 파인튜닝
-
-
-DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.
-
-
+> [!WARNING]
+> DreamBooth 파인튜닝은 하이퍼파라미터에 매우 민감하고 과적합되기 쉽습니다. 적절한 하이퍼파라미터를 선택하는 데 도움이 되도록 다양한 권장 설정이 포함된 [심층 분석](https://huggingface.co/blog/dreambooth)을 살펴보는 것이 좋습니다.
@@ -176,11 +173,8 @@ python train_dreambooth_flax.py \
해당 스크립트를 사용하면 `unet`과 함께 `text_encoder`를 파인튜닝할 수 있습니다. 실험에서(자세한 내용은 [🧨 Diffusers를 사용해 DreamBooth로 Stable Diffusion 학습하기](https://huggingface.co/blog/dreambooth) 게시물을 확인하세요), 특히 얼굴 이미지를 생성할 때 훨씬 더 나은 결과를 얻을 수 있습니다.
-
-
-텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.
-
-
+> [!WARNING]
+> 텍스트 인코더를 학습시키려면 추가 메모리가 필요해 16GB GPU로는 동작하지 않습니다. 이 옵션을 사용하려면 최소 24GB VRAM이 필요합니다.
`--train_text_encoder` 인수를 학습 스크립트에 전달하여 `text_encoder` 및 `unet`을 파인튜닝할 수 있습니다:
diff --git a/docs/source/ko/training/lora.md b/docs/source/ko/training/lora.md
index 5bcef27143..515e3fd65e 100644
--- a/docs/source/ko/training/lora.md
+++ b/docs/source/ko/training/lora.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-
-
-현재 LoRA는 [`UNet2DConditionalModel`]의 어텐션 레이어에서만 지원됩니다.
-
-
+> [!WARNING]
+> 현재 LoRA는 [`UNet2DConditionalModel`]의 어텐션 레이어에서만 지원됩니다.
[LoRA(Low-Rank Adaptation of Large Language Models)](https://huggingface.co/papers/2106.09685)는 메모리를 적게 사용하면서 대규모 모델의 학습을 가속화하는 학습 방법입니다. 이는 rank-decomposition weight 행렬 쌍(**업데이트 행렬**이라고 함)을 추가하고 새로 추가된 가중치**만** 학습합니다. 여기에는 몇 가지 장점이 있습니다.
@@ -28,11 +25,8 @@ specific language governing permissions and limitations under the License.
- 메모리 효율성이 향상되어 Tesla T4, RTX 3080 또는 RTX 2080 Ti와 같은 소비자용 GPU에서 파인튜닝을 실행할 수 있습니다! T4와 같은 GPU는 무료이며 Kaggle 또는 Google Colab 노트북에서 쉽게 액세스할 수 있습니다.
-
-
-💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것이 매우 효율적으로 죻은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!
-
-
+> [!TIP]
+> 💡 LoRA는 어텐션 레이어에만 한정되지는 않습니다. 저자는 언어 모델의 어텐션 레이어를 수정하는 것이 매우 효율적으로 죻은 성능을 얻기에 충분하다는 것을 발견했습니다. 이것이 LoRA 가중치를 모델의 어텐션 레이어에 추가하는 것이 일반적인 이유입니다. LoRA 작동 방식에 대한 자세한 내용은 [Using LoRA for effective Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) 블로그를 확인하세요!
[cloneofsimo](https://github.com/cloneofsimo)는 인기 있는 [lora](https://github.com/cloneofsimo/lora) GitHub 리포지토리에서 Stable Diffusion을 위한 LoRA 학습을 최초로 시도했습니다. 🧨 Diffusers는 [text-to-image 생성](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) 및 [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)을 지원합니다. 이 가이드는 두 가지를 모두 수행하는 방법을 보여줍니다.
@@ -104,11 +98,8 @@ accelerate launch train_dreambooth_lora.py \
*기본 모델의 가중치 위에* 파인튜닝된 DreamBooth 모델에서 LoRA 가중치를 불러온 다음, 더 빠른 추론을 위해 파이프라인을 GPU로 이동합니다. LoRA 가중치를 프리징된 사전 훈련된 모델 가중치와 병합할 때, 선택적으로 'scale' 매개변수로 어느 정도의 가중치를 병합할 지 조절할 수 있습니다:
-
-
-💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.
-
-
+> [!TIP]
+> 💡 `0`의 `scale` 값은 LoRA 가중치를 사용하지 않아 원래 모델의 가중치만 사용한 것과 같고, `1`의 `scale` 값은 파인튜닝된 LoRA 가중치만 사용함을 의미합니다. 0과 1 사이의 값들은 두 결과들 사이로 보간됩니다.
```py
>>> pipe.unet.load_attn_procs(model_path)
diff --git a/docs/source/ko/training/text2image.md b/docs/source/ko/training/text2image.md
index 4283f73ed9..b26603bf1b 100644
--- a/docs/source/ko/training/text2image.md
+++ b/docs/source/ko/training/text2image.md
@@ -13,11 +13,8 @@ specific language governing permissions and limitations under the License.
# Text-to-image
-
-
-text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.
-
-
+> [!WARNING]
+> text-to-image 파인튜닝 스크립트는 experimental 상태입니다. 과적합하기 쉽고 치명적인 망각과 같은 문제에 부딪히기 쉽습니다. 자체 데이터셋에서 최상의 결과를 얻으려면 다양한 하이퍼파라미터를 탐색하는 것이 좋습니다.
Stable Diffusion과 같은 text-to-image 모델은 텍스트 프롬프트에서 이미지를 생성합니다. 이 가이드는 PyTorch 및 Flax를 사용하여 자체 데이터셋에서 [`CompVis/stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) 모델로 파인튜닝하는 방법을 보여줍니다. 이 가이드에 사용된 text-to-image 파인튜닝을 위한 모든 학습 스크립트에 관심이 있는 경우 이 [리포지토리](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)에서 자세히 찾을 수 있습니다.
diff --git a/docs/source/ko/training/text_inversion.md b/docs/source/ko/training/text_inversion.md
index b27bed7d14..d8b44930e3 100644
--- a/docs/source/ko/training/text_inversion.md
+++ b/docs/source/ko/training/text_inversion.md
@@ -23,11 +23,8 @@ specific language governing permissions and limitations under the License.
이 가이드에서는 textual-inversion으로 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델을 학습하는 방법을 설명합니다. 이 가이드에서 사용된 모든 textual-inversion 학습 스크립트는 [여기](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)에서 확인할 수 있습니다. 내부적으로 어떻게 작동하는지 자세히 살펴보고 싶으시다면 해당 링크를 참조해주시기 바랍니다.
-
-
-[Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 시간이 지남에 따라 더 많은 콘셉트들이 추가되어 유용한 리소스로 성장할 것입니다!
-
-
+> [!TIP]
+> [Stable Diffusion Textual Inversion Concepts Library](https://huggingface.co/sd-concepts-library)에는 커뮤니티에서 제작한 학습된 textual-inversion 모델들이 있습니다. 시간이 지남에 따라 더 많은 콘셉트들이 추가되어 유용한 리소스로 성장할 것입니다!
시작하기 전에 학습을 위한 의존성 라이브러리들을 설치해야 합니다:
@@ -100,11 +97,8 @@ snapshot_download(
- `token_identifier.txt`
- `type_of_concept.txt`.
-
-
-💡V100 GPU 1개를 기준으로 전체 학습에는 최대 1시간이 걸립니다. 학습이 완료되기를 기다리는 동안 궁금한 점이 있으면 아래 섹션에서 [textual-inversion이 어떻게 작동하는지](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) 자유롭게 확인하세요 !
-
-
+> [!TIP]
+> 💡V100 GPU 1개를 기준으로 전체 학습에는 최대 1시간이 걸립니다. 학습이 완료되기를 기다리는 동안 궁금한 점이 있으면 아래 섹션에서 [textual-inversion이 어떻게 작동하는지](https://huggingface.co/docs/diffusers/training/text_inversion#how-it-works) 자유롭게 확인하세요 !
@@ -128,15 +122,12 @@ accelerate launch textual_inversion.py \
--push_to_hub
```
-
-
-💡학습 성능을 올리기 위해, 플레이스홀더 토큰(``)을 (단일한 임베딩 벡터가 아닌) 복수의 임베딩 벡터로 표현하는 것 역시 고려할 있습니다. 이러한 트릭이 모델이 보다 복잡한 이미지의 스타일(앞서 말한 콘셉트)을 더 잘 캡처하는 데 도움이 될 수 있습니다. 복수의 임베딩 벡터 학습을 활성화하려면 다음 옵션을 전달하십시오.
-
-```bash
---num_vectors=5
-```
-
-
+> [!TIP]
+> 💡학습 성능을 올리기 위해, 플레이스홀더 토큰(``)을 (단일한 임베딩 벡터가 아닌) 복수의 임베딩 벡터로 표현하는 것 역시 고려할 있습니다. 이러한 트릭이 모델이 보다 복잡한 이미지의 스타일(앞서 말한 콘셉트)을 더 잘 캡처하는 데 도움이 될 수 있습니다. 복수의 임베딩 벡터 학습을 활성화하려면 다음 옵션을 전달하십시오.
+>
+> ```bash
+> --num_vectors=5
+> ```
@@ -193,11 +184,8 @@ textual-inversion 스크립트는 기본적으로 textual-inversion을 통해
-
-
-💡 커뮤니티는 [sd-concepts-library](https://huggingface.co/sd-concepts-library) 라는 대규모의 textual-inversion 임베딩 벡터 라이브러리를 만들었습니다. textual-inversion 임베딩을 밑바닥부터 학습하는 대신, 해당 라이브러리에 본인이 찾는 textual-inversion 임베딩이 이미 추가되어 있지 않은지를 확인하는 것도 좋은 방법이 될 것 같습니다.
-
-
+> [!TIP]
+> 💡 커뮤니티는 [sd-concepts-library](https://huggingface.co/sd-concepts-library) 라는 대규모의 textual-inversion 임베딩 벡터 라이브러리를 만들었습니다. textual-inversion 임베딩을 밑바닥부터 학습하는 대신, 해당 라이브러리에 본인이 찾는 textual-inversion 임베딩이 이미 추가되어 있지 않은지를 확인하는 것도 좋은 방법이 될 것 같습니다.
textual-inversion 임베딩 벡터을 불러오기 위해서는, 먼저 해당 임베딩 벡터를 학습할 때 사용한 모델을 불러와야 합니다. 여기서는 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/docs/diffusers/training/stable-diffusion-v1-5/stable-diffusion-v1-5) 모델이 사용되었다고 가정하고 불러오겠습니다.
diff --git a/docs/source/ko/training/unconditional_training.md b/docs/source/ko/training/unconditional_training.md
index c8c463da6b..04a9a6c7ea 100644
--- a/docs/source/ko/training/unconditional_training.md
+++ b/docs/source/ko/training/unconditional_training.md
@@ -78,11 +78,8 @@ write_basic_config()
학습 스크립트는 `diffusion_pytorch_model.bin` 파일을 생성하고, 그것을 당신의 리포지토리에 저장합니다.
-
-
-💡 전체 학습은 V100 GPU 4개를 사용할 경우, 2시간이 소요됩니다.
-
-
+> [!TIP]
+> 💡 전체 학습은 V100 GPU 4개를 사용할 경우, 2시간이 소요됩니다.
예를 들어, [Oxford Flowers](https://huggingface.co/datasets/huggan/flowers-102-categories) 데이터셋을 사용해 파인튜닝할 경우:
diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md
index 2c4c89edd1..05ce1037b5 100644
--- a/docs/source/ko/tutorials/basic_training.md
+++ b/docs/source/ko/tutorials/basic_training.md
@@ -19,11 +19,8 @@ Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한
이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) 데이터셋의 하위 집합에서 [`UNet2DModel`] 모델을 학습하는 방법을 가르쳐줄 것입니다.
-
-
-💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!
-
-
+> [!TIP]
+> 💡 이 학습 튜토리얼은 [Training with 🧨 Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) 노트북 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!
시작 전에, 🤗 Datasets을 불러오고 전처리하기 위해 데이터셋이 설치되어 있는지 다수 GPU에서 학습을 간소화하기 위해 🤗 Accelerate 가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 [TensorBoard](https://www.tensorflow.org/tensorboard)를 또한 설치하세요. (또한 학습 추적을 위해 [Weights & Biases](https://docs.wandb.ai/)를 사용할 수 있습니다.)
diff --git a/docs/source/ko/using-diffusers/controlling_generation.md b/docs/source/ko/using-diffusers/controlling_generation.md
index 1b9a8b5df5..db22fe042d 100644
--- a/docs/source/ko/using-diffusers/controlling_generation.md
+++ b/docs/source/ko/using-diffusers/controlling_generation.md
@@ -85,12 +85,9 @@ Pix2Pix Zero는 합성 이미지와 실제 이미지를 편집하는 데 모두
다음으로 편집할 컨셉과 새로운 타겟 컨셉에 대한 이미지 캡션을 생성합니다. 이를 위해 [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5)와 같은 모델을 사용할 수 있습니다. 그런 다음 텍스트 인코더를 통해 소스 개념과 대상 개념 모두에 대한 "평균" 프롬프트 임베딩을 생성합니다. 마지막으로, 합성 이미지를 편집하기 위해 pix2pix-zero 알고리즘을 사용합니다.
- 실제 이미지를 편집하려면 먼저 [BLIP](https://huggingface.co/docs/transformers/model_doc/blip)과 같은 모델을 사용하여 이미지 캡션을 생성합니다. 그런 다음 프롬프트와 이미지에 ddim 반전을 적용하여 "역(inverse)" latents을 생성합니다. 이전과 마찬가지로 소스 및 대상 개념 모두에 대한 "평균(mean)" 프롬프트 임베딩이 생성되고 마지막으로 "역(inverse)" latents와 결합된 pix2pix-zero 알고리즘이 이미지를 편집하는 데 사용됩니다.
-
-
-Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.
-즉, 이 모델은 다음과 같이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다(../api/pipelines/stable_diffusion/pix2pix_zero#usage-example).
-
-
+> [!TIP]
+> Pix2Pix Zero는 '제로 샷(zero-shot)' 이미지 편집이 가능한 최초의 모델입니다.
+> 즉, 이 모델은 다음과 같이 일반 소비자용 GPU에서 1분 이내에 이미지를 편집할 수 있습니다(../api/pipelines/stable_diffusion/pix2pix_zero#usage-example).
위에서 언급했듯이 Pix2Pix Zero에는 특정 개념으로 세대를 유도하기 위해 (UNet, VAE 또는 텍스트 인코더가 아닌) latents을 최적화하는 기능이 포함되어 있습니다.즉, 전체 파이프라인에 표준 [StableDiffusionPipeline](../api/pipelines/stable_diffusion/text2img)보다 더 많은 메모리가 필요할 수 있습니다.
@@ -140,13 +137,10 @@ SAG는 고빈도 세부 정보를 기반으로 하지 않은 예측에서 완전
사용 방법에 대한 자세한 내용은 [여기](../api/pipelines/stable_diffusion_2#depthtoimage)를 참조하세요.
-
-
-InstructPix2Pix와 Pix2Pix Zero와 같은 방법의 중요한 차이점은 전자의 경우
-는 사전 학습된 가중치를 미세 조정하는 반면, 후자는 그렇지 않다는 것입니다. 즉, 다음을 수행할 수 있습니다.
-사용 가능한 모든 안정적 확산 모델에 Pix2Pix Zero를 적용할 수 있습니다.
-
-
+> [!TIP]
+> InstructPix2Pix와 Pix2Pix Zero와 같은 방법의 중요한 차이점은 전자의 경우
+> 는 사전 학습된 가중치를 미세 조정하는 반면, 후자는 그렇지 않다는 것입니다. 즉, 다음을 수행할 수 있습니다.
+> 사용 가능한 모든 안정적 확산 모델에 Pix2Pix Zero를 적용할 수 있습니다.
## MultiDiffusion Panorama
diff --git a/docs/source/ko/using-diffusers/custom_pipeline_overview.md b/docs/source/ko/using-diffusers/custom_pipeline_overview.md
index b143bf8ab0..caeeca8cef 100644
--- a/docs/source/ko/using-diffusers/custom_pipeline_overview.md
+++ b/docs/source/ko/using-diffusers/custom_pipeline_overview.md
@@ -20,11 +20,8 @@ specific language governing permissions and limitations under the License.
허브에서 커뮤니티 파이프라인을 로드하려면, 커뮤니티 파이프라인의 리포지토리 ID와 (파이프라인 가중치 및 구성 요소를 로드하려는) 모델의 리포지토리 ID를 인자로 전달해야 합니다. 예를 들어, 아래 예시에서는 `hf-internal-testing/diffusers-dummy-pipeline`에서 더미 파이프라인을 불러오고, `google/ddpm-cifar10-32`에서 파이프라인의 가중치와 컴포넌트들을 로드합니다.
-
-
-🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰하는 것입니다. 코드를 자동으로 불러오고 실행하기 앞서 반드시 온라인으로 해당 코드의 신뢰성을 검사하세요!
-
-
+> [!WARNING]
+> 🔒 허깅 페이스 허브에서 커뮤니티 파이프라인을 불러오는 것은 곧 해당 코드가 안전하다고 신뢰하는 것입니다. 코드를 자동으로 불러오고 실행하기 앞서 반드시 온라인으로 해당 코드의 신뢰성을 검사하세요!
```py
from diffusers import DiffusionPipeline
diff --git a/docs/source/ko/using-diffusers/diffedit.md b/docs/source/ko/using-diffusers/diffedit.md
index 74b9e97831..edf23f0214 100644
--- a/docs/source/ko/using-diffusers/diffedit.md
+++ b/docs/source/ko/using-diffusers/diffedit.md
@@ -156,11 +156,8 @@ print(source_prompts)
print(target_prompts)
```
-
-
-다양한 품질의 텍스트를 생성하는 전략에 대해 자세히 알아보려면 [생성 전략](https://huggingface.co/docs/transformers/main/en/generation_strategies) 가이드를 참조하세요.
-
-
+> [!TIP]
+> 다양한 품질의 텍스트를 생성하는 전략에 대해 자세히 알아보려면 [생성 전략](https://huggingface.co/docs/transformers/main/en/generation_strategies) 가이드를 참조하세요.
텍스트 인코딩을 위해 [`StableDiffusionDiffEditPipeline`]에서 사용하는 텍스트 인코더 모델을 불러옵니다. 텍스트 인코더를 사용하여 텍스트 임베딩을 계산합니다:
diff --git a/docs/source/ko/using-diffusers/img2img.md b/docs/source/ko/using-diffusers/img2img.md
index 8da840f748..3901fb755f 100644
--- a/docs/source/ko/using-diffusers/img2img.md
+++ b/docs/source/ko/using-diffusers/img2img.md
@@ -53,11 +53,8 @@ init_image
-
-
-💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성합니다.
-
-
+> [!TIP]
+> 💡 `strength`는 입력 이미지에 추가되는 노이즈의 양을 제어하는 0.0에서 1.0 사이의 값입니다. 1.0에 가까운 값은 다양한 변형을 허용하지만 입력 이미지와 의미적으로 일치하지 않는 이미지를 생성합니다.
프롬프트를 정의하고(지브리 스타일(Ghibli-style)에 맞게 조정된 이 체크포인트의 경우 프롬프트 앞에 `ghibli style` 토큰을 붙여야 합니다) 파이프라인을 실행합니다:
diff --git a/docs/source/ko/using-diffusers/inpaint.md b/docs/source/ko/using-diffusers/inpaint.md
index adf1251176..6c0c08bf73 100644
--- a/docs/source/ko/using-diffusers/inpaint.md
+++ b/docs/source/ko/using-diffusers/inpaint.md
@@ -14,7 +14,7 @@ specific language governing permissions and limitations under the License.
[[open-in-colab]]
-[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.
+[`StableDiffusionInpaintPipeline`]은 마스크와 텍스트 프롬프트를 제공하여 이미지의 특정 부분을 편집할 수 있도록 합니다. 이 기능은 인페인팅 작업을 위해 특별히 훈련된 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting)과 같은 Stable Diffusion 버전을 사용합니다.
먼저 [`StableDiffusionInpaintPipeline`] 인스턴스를 불러옵니다:
@@ -27,7 +27,7 @@ from io import BytesIO
from diffusers import StableDiffusionInpaintPipeline
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
+ "stable-diffusion-v1-5/stable-diffusion-inpainting",
torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")
@@ -59,17 +59,5 @@ image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
:-------------------------:|:-------------------------:|:-------------------------:|-------------------------:|

|

| ***Face of a yellow cat, high resolution, sitting on a park bench*** |

|
-
-
-이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
-
-
-
-아래 Space에서 이미지 인페인팅을 직접 해보세요!
-
-
+> [!WARNING]
+> 이전의 실험적인 인페인팅 구현에서는 품질이 낮은 다른 프로세스를 사용했습니다. 이전 버전과의 호환성을 보장하기 위해 새 모델이 포함되지 않은 사전학습된 파이프라인을 불러오면 이전 인페인팅 방법이 계속 적용됩니다.
diff --git a/docs/source/ko/using-diffusers/kandinsky.md b/docs/source/ko/using-diffusers/kandinsky.md
index cc554c67f9..8eff8f5629 100644
--- a/docs/source/ko/using-diffusers/kandinsky.md
+++ b/docs/source/ko/using-diffusers/kandinsky.md
@@ -31,15 +31,12 @@ Kandinsky 모델은 일련의 다국어 text-to-image 생성 모델입니다. Ka
#!pip install -q diffusers transformers accelerate
```
-
-
-Kandinsky 2.1과 2.2의 사용법은 매우 유사합니다! 유일한 차이점은 Kandinsky 2.2는 latents를 디코딩할 때 `프롬프트`를 입력으로 받지 않는다는 것입니다. 대신, Kandinsky 2.2는 디코딩 중에는 `image_embeds`만 받아들입니다.
-
-
-
-Kandinsky 3는 더 간결한 아키텍처를 가지고 있으며 prior 모델이 필요하지 않습니다. 즉, [Stable Diffusion XL](sdxl)과 같은 다른 diffusion 모델과 사용법이 동일합니다.
-
-
+> [!WARNING]
+> Kandinsky 2.1과 2.2의 사용법은 매우 유사합니다! 유일한 차이점은 Kandinsky 2.2는 latents를 디코딩할 때 `프롬프트`를 입력으로 받지 않는다는 것입니다. 대신, Kandinsky 2.2는 디코딩 중에는 `image_embeds`만 받아들입니다.
+>
+>
+>
+> Kandinsky 3는 더 간결한 아키텍처를 가지고 있으며 prior 모델이 필요하지 않습니다. 즉, [Stable Diffusion XL](sdxl)과 같은 다른 diffusion 모델과 사용법이 동일합니다.
## Text-to-image
@@ -321,20 +318,17 @@ make_image_grid([original_image.resize((512, 512)), image.resize((512, 512))], r
## Inpainting
-
-
-⚠️ Kandinsky 모델은 이제 검은색 픽셀 대신 ⬜️ **흰색 픽셀**을 사용하여 마스크 영역을 표현합니다. 프로덕션에서 [`KandinskyInpaintPipeline`]을 사용하는 경우 흰색 픽셀을 사용하도록 마스크를 변경해야 합니다:
-
-```py
-# PIL 입력에 대해
-import PIL.ImageOps
-mask = PIL.ImageOps.invert(mask)
-
-# PyTorch와 NumPy 입력에 대해
-mask = 1 - mask
-```
-
-
+> [!WARNING]
+> ⚠️ Kandinsky 모델은 이제 검은색 픽셀 대신 ⬜️ **흰색 픽셀**을 사용하여 마스크 영역을 표현합니다. 프로덕션에서 [`KandinskyInpaintPipeline`]을 사용하는 경우 흰색 픽셀을 사용하도록 마스크를 변경해야 합니다:
+>
+> ```py
+> # PIL 입력에 대해
+> import PIL.ImageOps
+> mask = PIL.ImageOps.invert(mask)
+>
+> # PyTorch와 NumPy 입력에 대해
+> mask = 1 - mask
+> ```
인페인팅에서는 원본 이미지, 원본 이미지에서 대체할 영역의 마스크, 인페인팅할 내용에 대한 텍스트 프롬프트가 필요합니다. Prior 파이프라인을 불러옵니다:
@@ -565,11 +559,8 @@ image
## ControlNet
-
-
-⚠️ ControlNet은 Kandinsky 2.2에서만 지원됩니다!
-
-
+> [!WARNING]
+> ⚠️ ControlNet은 Kandinsky 2.2에서만 지원됩니다!
ControlNet을 사용하면 depth map이나 edge detection와 같은 추가 입력을 통해 사전학습된 large diffusion 모델을 conditioning할 수 있습니다. 예를 들어, 모델이 depth map의 구조를 이해하고 보존할 수 있도록 깊이 맵으로 Kandinsky 2.2를 conditioning할 수 있습니다.
diff --git a/docs/source/ko/using-diffusers/loading.md b/docs/source/ko/using-diffusers/loading.md
index 3d6b7634b4..2160acacc2 100644
--- a/docs/source/ko/using-diffusers/loading.md
+++ b/docs/source/ko/using-diffusers/loading.md
@@ -30,11 +30,8 @@ diffusion 모델의 훈련과 추론에 필요한 모든 것은 [`DiffusionPipel
## Diffusion 파이프라인
-
-
-💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면, [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.
-
-
+> [!TIP]
+> 💡 [`DiffusionPipeline`] 클래스가 동작하는 방식에 보다 자세한 내용이 궁금하다면, [DiffusionPipeline explained](#diffusionpipeline에-대해-알아보기) 섹션을 확인해보세요.
[`DiffusionPipeline`] 클래스는 diffusion 모델을 [허브](https://huggingface.co/models?library=diffusers)로부터 불러오는 가장 심플하면서 보편적인 방식입니다. [`DiffusionPipeline.from_pretrained`] 메서드는 적합한 파이프라인 클래스를 자동으로 탐지하고, 필요한 구성요소(configuration)와 가중치(weight) 파일들을 다운로드하고 캐싱한 다음, 해당 파이프라인 인스턴스를 반환합니다.
@@ -175,11 +172,8 @@ Variant란 일반적으로 다음과 같은 체크포인트들을 의미합니
- `torch.float16`과 같이 정밀도는 더 낮지만, 용량 역시 더 작은 부동소수점 타입의 가중치를 사용하는 체크포인트. *(다만 이와 같은 variant의 경우, 추가적인 훈련과 CPU환경에서의 구동이 불가능합니다.)*
- Non-EMA 가중치를 사용하는 체크포인트. *(Non-EMA 가중치의 경우, 파인 튜닝 단계에서 사용하는 것이 권장되는데, 추론 단계에선 사용하지 않는 것이 권장됩니다.)*
-
-
-💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).
-
-
+> [!TIP]
+> 💡 모델 구조는 동일하지만 서로 다른 학습 환경에서 서로 다른 데이터셋으로 학습된 체크포인트들이 있을 경우, 해당 체크포인트들은 variant 단계가 아닌 리포지토리 단계에서 분리되어 관리되어야 합니다. (즉, 해당 체크포인트들은 서로 다른 리포지토리에서 따로 관리되어야 합니다. 예시: [`stable-diffusion-v1-4`], [`stable-diffusion-v1-5`]).
| **checkpoint type** | **weight name** | **argument for loading weights** |
| ------------------- | ----------------------------------- | -------------------------------- |
diff --git a/docs/source/ko/using-diffusers/loading_adapters.md b/docs/source/ko/using-diffusers/loading_adapters.md
index f0d085bc6a..e7ae116575 100644
--- a/docs/source/ko/using-diffusers/loading_adapters.md
+++ b/docs/source/ko/using-diffusers/loading_adapters.md
@@ -18,11 +18,8 @@ specific language governing permissions and limitations under the License.
이 가이드에서는 DreamBooth, textual inversion 및 LoRA 가중치를 불러오는 방법을 설명합니다.
-
-
-사용할 체크포인트와 임베딩은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)에서 찾아보시기 바랍니다.
-
-
+> [!TIP]
+> 사용할 체크포인트와 임베딩은 [Stable Diffusion Conceptualizer](https://huggingface.co/spaces/sd-concepts-library/stable-diffusion-conceptualizer), [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer), [Diffusers Models Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery)에서 찾아보시기 바랍니다.
## DreamBooth
@@ -101,11 +98,8 @@ image
[Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685)은 속도가 빠르고 파일 크기가 (수백 MB로) 작기 때문에 널리 사용되는 학습 기법입니다. 이 가이드의 다른 방법과 마찬가지로, LoRA는 몇 장의 이미지만으로 새로운 스타일을 학습하도록 모델을 학습시킬 수 있습니다. 이는 diffusion 모델에 새로운 가중치를 삽입한 다음 전체 모델 대신 새로운 가중치만 학습시키는 방식으로 작동합니다. 따라서 LoRA를 더 빠르게 학습시키고 더 쉽게 저장할 수 있습니다.
-
-
-LoRA는 다른 학습 방법과 함께 사용할 수 있는 매우 일반적인 학습 기법입니다. 예를 들어, DreamBooth와 LoRA로 모델을 학습하는 것이 일반적입니다. 또한 새롭고 고유한 이미지를 생성하기 위해 여러 개의 LoRA를 불러오고 병합하는 것이 점점 더 일반화되고 있습니다. 병합은 이 불러오기 가이드의 범위를 벗어나므로 자세한 내용은 심층적인 [LoRA 병합](merge_loras) 가이드에서 확인할 수 있습니다.
-
-
+> [!TIP]
+> LoRA는 다른 학습 방법과 함께 사용할 수 있는 매우 일반적인 학습 기법입니다. 예를 들어, DreamBooth와 LoRA로 모델을 학습하는 것이 일반적입니다. 또한 새롭고 고유한 이미지를 생성하기 위해 여러 개의 LoRA를 불러오고 병합하는 것이 점점 더 일반화되고 있습니다. 병합은 이 불러오기 가이드의 범위를 벗어나므로 자세한 내용은 심층적인 [LoRA 병합](merge_loras) 가이드에서 확인할 수 있습니다.
LoRA는 다른 모델과 함께 사용해야 합니다:
@@ -184,11 +178,8 @@ pipe.set_adapters("my_adapter", scales)
이는 여러 어댑터에서도 작동합니다. 방법은 [이 가이드](https://huggingface.co/docs/diffusers/tutorials/using_peft_for_inference#customize-adapters-strength)를 참조하세요.
-
-
-현재 [`~loaders.LoraLoaderMixin.set_adapters`]는 어텐션 가중치의 스케일링만 지원합니다. LoRA에 다른 부분(예: resnets or down-/upsamplers)이 있는 경우 1.0의 스케일을 유지합니다.
-
-
+> [!WARNING]
+> 현재 [`~loaders.LoraLoaderMixin.set_adapters`]는 어텐션 가중치의 스케일링만 지원합니다. LoRA에 다른 부분(예: resnets or down-/upsamplers)이 있는 경우 1.0의 스케일을 유지합니다.
### Kohya와 TheLastBen
@@ -222,14 +213,11 @@ image = pipeline(prompt).images[0]
image
```
-
-
-Kohya LoRA를 🤗 Diffusers와 함께 사용할 때 몇 가지 제한 사항이 있습니다:
-
-- [여기](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736)에 설명된 여러 가지 이유로 인해 이미지가 ComfyUI와 같은 UI에서 생성된 이미지와 다르게 보일 수 있습니다.
-- [LyCORIS 체크포인트](https://github.com/KohakuBlueleaf/LyCORIS)가 완전히 지원되지 않습니다. [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 및 LoCon 모듈로 LyCORIS 체크포인트를 불러올 수 있지만, Hada 및 LoKR은 지원되지 않습니다.
-
-
+> [!WARNING]
+> Kohya LoRA를 🤗 Diffusers와 함께 사용할 때 몇 가지 제한 사항이 있습니다:
+>
+> - [여기](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736)에 설명된 여러 가지 이유로 인해 이미지가 ComfyUI와 같은 UI에서 생성된 이미지와 다르게 보일 수 있습니다.
+> - [LyCORIS 체크포인트](https://github.com/KohakuBlueleaf/LyCORIS)가 완전히 지원되지 않습니다. [`~loaders.LoraLoaderMixin.load_lora_weights`] 메서드는 LoRA 및 LoCon 모듈로 LyCORIS 체크포인트를 불러올 수 있지만, Hada 및 LoKR은 지원되지 않습니다.
@@ -326,9 +314,8 @@ pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name=
IP-Adapter FaceID 모델은 CLIP 이미지 임베딩 대신 `insightface`에서 생성한 이미지 임베딩을 사용하는 실험적인 IP Adapter입니다. 이러한 모델 중 일부는 LoRA를 사용하여 ID 일관성을 개선하기도 합니다.
이러한 모델을 사용하려면 `insightface`와 해당 요구 사항을 모두 설치해야 합니다.
-
-InsightFace 사전학습된 모델은 비상업적 연구 목적으로만 사용할 수 있으므로, IP-Adapter-FaceID 모델은 연구 목적으로만 릴리즈되었으며 상업적 용도로는 사용할 수 없습니다.
-
+> [!WARNING]
+> InsightFace 사전학습된 모델은 비상업적 연구 목적으로만 사용할 수 있으므로, IP-Adapter-FaceID 모델은 연구 목적으로만 릴리즈되었으며 상업적 용도로는 사용할 수 없습니다.
```py
pipeline = AutoPipelineForText2Image.from_pretrained(
diff --git a/docs/source/ko/using-diffusers/other-formats.md b/docs/source/ko/using-diffusers/other-formats.md
index 3034551f48..f5a71f56eb 100644
--- a/docs/source/ko/using-diffusers/other-formats.md
+++ b/docs/source/ko/using-diffusers/other-formats.md
@@ -14,11 +14,8 @@ specific language governing permissions and limitations under the License.
Stable Diffusion 모델들은 학습 및 저장된 프레임워크와 다운로드 위치에 따라 다양한 형식으로 제공됩니다. 이러한 형식을 🤗 Diffusers에서 사용할 수 있도록 변환하면 추론을 위한 [다양한 스케줄러 사용](schedulers), 사용자 지정 파이프라인 구축, 추론 속도 최적화를 위한 다양한 기법과 방법 등 라이브러리에서 지원하는 모든 기능을 사용할 수 있습니다.
-
-
-우리는 `.safetensors` 형식을 추천합니다. 왜냐하면 기존의 pickled 파일은 취약하고 머신에서 코드를 실행할 때 악용될 수 있는 것에 비해 훨씬 더 안전합니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)
-
-
+> [!TIP]
+> 우리는 `.safetensors` 형식을 추천합니다. 왜냐하면 기존의 pickled 파일은 취약하고 머신에서 코드를 실행할 때 악용될 수 있는 것에 비해 훨씬 더 안전합니다. (safetensors 불러오기 가이드에서 자세히 알아보세요.)
이 가이드에서는 다른 Stable Diffusion 형식을 🤗 Diffusers와 호환되도록 변환하는 방법을 설명합니다.
diff --git a/docs/source/ko/using-diffusers/schedulers.md b/docs/source/ko/using-diffusers/schedulers.md
index 55424c9982..b12c08b8c8 100644
--- a/docs/source/ko/using-diffusers/schedulers.md
+++ b/docs/source/ko/using-diffusers/schedulers.md
@@ -318,12 +318,9 @@ images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```
-
-
-다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.
-
-- `FlaxLMSDiscreteScheduler`
-- `FlaxDDPMScheduler`
-
-
+> [!WARNING]
+> 다음 Flax 스케줄러는 *아직* Flax Stable Diffusion 파이프라인과 호환되지 않습니다.
+>
+> - `FlaxLMSDiscreteScheduler`
+> - `FlaxDDPMScheduler`
diff --git a/docs/source/ko/using-diffusers/shap-e.md b/docs/source/ko/using-diffusers/shap-e.md
index abf5a182b3..4c9d7fb7d1 100644
--- a/docs/source/ko/using-diffusers/shap-e.md
+++ b/docs/source/ko/using-diffusers/shap-e.md
@@ -151,11 +151,8 @@ images = pipe(prompt, guidance_scale=guidance_scale, num_inference_steps=64, fra
메시 출력을 `ply` 파일로 저장하려면 [`~utils.export_to_ply`] 함수를 사용합니다:
-
-
-선택적으로 [`~utils.export_to_obj`] 함수를 사용하여 메시 출력을 `obj` 파일로 저장할 수 있습니다. 다양한 형식으로 메시 출력을 저장할 수 있어 다운스트림에서 더욱 유연하게 사용할 수 있습니다!
-
-
+> [!TIP]
+> 선택적으로 [`~utils.export_to_obj`] 함수를 사용하여 메시 출력을 `obj` 파일로 저장할 수 있습니다. 다양한 형식으로 메시 출력을 저장할 수 있어 다운스트림에서 더욱 유연하게 사용할 수 있습니다!
```py
from diffusers.utils import export_to_ply
diff --git a/docs/source/ko/using-diffusers/unconditional_image_generation.md b/docs/source/ko/using-diffusers/unconditional_image_generation.md
index c3eaac4b03..b8fe800578 100644
--- a/docs/source/ko/using-diffusers/unconditional_image_generation.md
+++ b/docs/source/ko/using-diffusers/unconditional_image_generation.md
@@ -20,11 +20,8 @@ Unconditional 이미지 생성은 비교적 간단한 작업입니다. 모델이
먼저 ['DiffusionPipeline']의 인스턴스를 생성하고 다운로드할 파이프라인의 [체크포인트](https://huggingface.co/models?library=diffusers&sort=downloads)를 지정합니다. 허브의 🧨 diffusion 체크포인트 중 하나를 사용할 수 있습니다(사용할 체크포인트는 나비 이미지를 생성합니다).
-
-
-💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.
-
-
+> [!TIP]
+> 💡 나만의 unconditional 이미지 생성 모델을 학습시키고 싶으신가요? 학습 가이드를 살펴보고 나만의 이미지를 생성하는 방법을 알아보세요.
이 가이드에서는 unconditional 이미지 생성에 ['DiffusionPipeline']과 [DDPM](https://huggingface.co/papers/2006.11239)을 사용합니다:
diff --git a/docs/source/ko/using-diffusers/write_own_pipeline.md b/docs/source/ko/using-diffusers/write_own_pipeline.md
index 45678763cc..ae6ce238ac 100644
--- a/docs/source/ko/using-diffusers/write_own_pipeline.md
+++ b/docs/source/ko/using-diffusers/write_own_pipeline.md
@@ -110,11 +110,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
보시다시피, 이것은 UNet 모델만 포함된 DDPM 파이프라인보다 더 복잡합니다. Stable Diffusion 모델에는 세 개의 개별 사전학습된 모델이 있습니다.
-
-
-💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.
-
-
+> [!TIP]
+> 💡 VAE, UNet 및 텍스트 인코더 모델의 작동방식에 대한 자세한 내용은 [How does Stable Diffusion work?](https://huggingface.co/blog/stable_diffusion#how-does-stable-diffusion-work) 블로그를 참조하세요.
이제 Stable Diffusion 파이프라인에 필요한 구성요소들이 무엇인지 알았으니, [`~ModelMixin.from_pretrained`] 메서드를 사용해 모든 구성요소를 불러옵니다. 사전학습된 체크포인트 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)에서 찾을 수 있으며, 각 구성요소들은 별도의 하위 폴더에 저장되어 있습니다:
@@ -151,11 +148,8 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
다음 단계는 임베딩을 생성하기 위해 텍스트를 토큰화하는 것입니다. 이 텍스트는 UNet 모델에서 condition으로 사용되고 입력 프롬프트와 유사한 방향으로 diffusion 프로세스를 조정하는 데 사용됩니다.
-
-
-💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.
-
-
+> [!TIP]
+> 💡 `guidance_scale` 매개변수는 이미지를 생성할 때 프롬프트에 얼마나 많은 가중치를 부여할지 결정합니다.
다른 프롬프트를 생성하고 싶다면 원하는 프롬프트를 자유롭게 선택하세요!
@@ -198,15 +192,12 @@ Stable Diffusion 은 text-to-image *latent diffusion* 모델입니다. latent di
그다음 diffusion 프로세스의 시작점으로 초기 랜덤 노이즈를 생성합니다. 이것이 이미지의 잠재적 표현이며 점차적으로 노이즈가 제거됩니다. 이 시점에서 `latent` 이미지는 최종 이미지 크기보다 작지만 나중에 모델이 이를 512x512 이미지 크기로 변환하므로 괜찮습니다.
-
-
-💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:
-
-```py
-2 ** (len(vae.config.block_out_channels) - 1) == 8
-```
-
-
+> [!TIP]
+> 💡 `vae` 모델에는 3개의 다운 샘플링 레이어가 있기 때문에 높이와 너비가 8로 나뉩니다. 다음을 실행하여 확인할 수 있습니다:
+>
+> ```py
+> 2 ** (len(vae.config.block_out_channels) - 1) == 8
+> ```
```py
>>> latents = torch.randn(
diff --git a/docs/source/pt/installation.md b/docs/source/pt/installation.md
index 1e83e36ca1..acc767110c 100644
--- a/docs/source/pt/installation.md
+++ b/docs/source/pt/installation.md
@@ -104,11 +104,8 @@ Esses comandos irá linkar a pasta que você clonou o repositório e os caminhos
Python então irá procurar dentro da pasta que você clonou além dos caminhos normais das bibliotecas.
Por exemplo, se o pacote python for tipicamente instalado no `~/anaconda3/envs/main/lib/python3.10/site-packages/`, o Python também irá procurar na pasta `~/diffusers/` que você clonou.
-
-
-Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
-
-
+> [!WARNING]
+> Você deve deixar a pasta `diffusers` se você quiser continuar usando a biblioteca.
Agora você pode facilmente atualizar seu clone para a última versão do 🤗 Diffusers com o seguinte comando:
diff --git a/docs/source/pt/quicktour.md b/docs/source/pt/quicktour.md
index 109f7e2712..5996b65a9c 100644
--- a/docs/source/pt/quicktour.md
+++ b/docs/source/pt/quicktour.md
@@ -24,11 +24,8 @@ Seja você um desenvolvedor ou um usuário, esse tour rápido irá introduzir vo
Esse tour rápido mostrará como usar o [`DiffusionPipeline`] para inferência, e então mostrará como combinar um modelo e um agendador para replicar o que está acontecendo dentro do [`DiffusionPipeline`].
-
-
-Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!
-
-
+> [!TIP]
+> Esse tour rápido é uma versão simplificada da introdução 🧨 Diffusers [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb) para ajudar você a começar rápido. Se você quer aprender mais sobre o objetivo do 🧨 Diffusers, filosofia de design, e detalhes adicionais sobre a API principal, veja o notebook!
Antes de começar, certifique-se de ter todas as bibliotecas necessárias instaladas:
@@ -56,11 +53,8 @@ Comece criando uma instância do [`DiffusionPipeline`] e especifique qual checkp
Você pode usar o [`DiffusionPipeline`] para qualquer [checkpoint](https://huggingface.co/models?library=diffusers&sort=downloads) armazenado no Hugging Face Hub.
Nesse quicktour, você carregará o checkpoint [`stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) para geração de texto para imagem.
-
-
-Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.
-
-
+> [!WARNING]
+> Para os modelos de [Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion), por favor leia cuidadosamente a [licença](https://huggingface.co/spaces/CompVis/stable-diffusion-license) primeiro antes de rodar o modelo. 🧨 Diffusers implementa uma verificação de segurança: [`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) para prevenir conteúdo ofensivo ou nocivo, mas as capacidades de geração de imagem aprimorada do modelo podem ainda produzir conteúdo potencialmente nocivo.
Para carregar o modelo com o método [`~DiffusionPipeline.from_pretrained`]:
@@ -204,11 +198,8 @@ Para geração de exemplos reais, você precisará de um agendador para guiar o
Agendadores gerenciam a retirada do ruído de uma amostra ruidosa para uma amostra menos ruidosa dado a saída do modelo - nesse caso, é o `noisy_residual`.
-
-
-🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.
-
-
+> [!TIP]
+> 🧨 Diffusers é uma caixa de ferramentas para construir sistemas de difusão. Enquanto o [`DiffusionPipeline`] é uma forma conveniente de começar com um sistema de difusão pré-construído, você também pode escolher seus próprios modelos e agendadores separadamente para construir um sistema de difusão personalizado.
Para o tour rápido, você irá instanciar o [`DDPMScheduler`] com o método [`~diffusers.ConfigMixin.from_config`]:
@@ -232,11 +223,8 @@ DDPMScheduler {
}
```
-
-
-💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!
-
-
+> [!TIP]
+> 💡 Perceba como o agendador é instanciado de uma configuração. Diferentemente de um modelo, um agendador não tem pesos treináveis e é livre de parâmetros!
Um dos parâmetros mais importante são:
diff --git a/docs/source/zh/conceptual/evaluation.md b/docs/source/zh/conceptual/evaluation.md
index e809c8730d..770d197be0 100644
--- a/docs/source/zh/conceptual/evaluation.md
+++ b/docs/source/zh/conceptual/evaluation.md
@@ -92,11 +92,8 @@ images = sd_pipeline(sample_prompts, num_images_per_prompt=1, generator=generato
当使用多个待评估模型为所有提示词生成若干图像后,这些结果将提交给人类评估员进行打分。有关DrawBench和PartiPrompts基准测试的更多细节,请参阅各自的论文。
-
-
-在模型训练过程中查看推理样本有助于评估训练进度。我们的[训练脚本](https://github.com/huggingface/diffusers/tree/main/examples/)支持此功能,并额外提供TensorBoard和Weights & Biases日志记录功能。
-
-
+> [!TIP]
+> 在模型训练过程中查看推理样本有助于评估训练进度。我们的[训练脚本](https://github.com/huggingface/diffusers/tree/main/examples/)支持此功能,并额外提供TensorBoard和Weights & Biases日志记录功能。
## 定量评估
@@ -189,11 +186,8 @@ print(f"v-1-5版本的CLIP分数: {sd_clip_score_1_5}")
结果表明[v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)检查点性能优于前代。但需注意,我们用于计算CLIP分数的提示词数量较少。实际评估时应使用更多样化且数量更大的提示词集。
-
-
-该分数存在固有局限性:训练数据中的标题是从网络爬取,并提取自图片关联的`alt`等标签。这些描述未必符合人类描述图像的方式,因此我们需要人工"设计"部分提示词。
-
-
+> [!WARNING]
+> 该分数存在固有局限性:训练数据中的标题是从网络爬取,并提取自图片关联的`alt`等标签。这些描述未必符合人类描述图像的方式,因此我们需要人工"设计"部分提示词。
### 图像条件式文本生成图像
@@ -402,11 +396,8 @@ print(f"CLIP方向相似度: {np.mean(scores)}")
该度量方法同样适用于类似流程,例如[`StableDiffusionPix2PixZeroPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/pix2pix_zero#diffusers.StableDiffusionPix2PixZeroPipeline)。
-
-
-CLIP分数和CLIP方向相似度都依赖CLIP模型,可能导致评估结果存在偏差。
-
-
+> [!TIP]
+> CLIP分数和CLIP方向相似度都依赖CLIP模型,可能导致评估结果存在偏差。
***扩展IS、FID(后文讨论)或KID等指标存在困难***,当被评估模型是在大型图文数据集(如[LAION-5B数据集](https://laion.ai/blog/laion-5b/))上预训练时。因为这些指标的底层都使用了在ImageNet-1k数据集上预训练的InceptionNet来提取图像特征。Stable Diffusion的预训练数据集与InceptionNet的预训练数据集可能重叠有限,因此不适合作为特征提取器。
@@ -536,19 +527,16 @@ FID分数越低越好。以下因素会影响FID结果:
对于最后两点,最佳实践是使用不同的随机种子和推理步数进行多次评估,然后报告平均结果。
-
-
-FID结果往往具有脆弱性,因为它依赖于许多因素:
-
-* 计算过程中使用的特定Inception模型
-* 计算实现的准确性
-* 图像格式(PNG和JPG的起点不同)
-
-需要注意的是,FID通常在比较相似实验时最有用,但除非作者仔细公开FID测量代码,否则很难复现论文结果。
-
-这些注意事项同样适用于其他相关指标,如KID和IS。
-
-
+> [!WARNING]
+> FID结果往往具有脆弱性,因为它依赖于许多因素:
+>
+> * 计算过程中使用的特定Inception模型
+> * 计算实现的准确性
+> * 图像格式(PNG和JPG的起点不同)
+>
+> 需要注意的是,FID通常在比较相似实验时最有用,但除非作者仔细公开FID测量代码,否则很难复现论文结果。
+>
+> 这些注意事项同样适用于其他相关指标,如KID和IS。
最后,让我们可视化检查这些`fake_images`。
diff --git a/docs/source/zh/installation.md b/docs/source/zh/installation.md
index fc77ea8c48..9941ed24ae 100644
--- a/docs/source/zh/installation.md
+++ b/docs/source/zh/installation.md
@@ -109,11 +109,8 @@ pip install -e ".[flax]"
现在,不只是在通常的库路径,Python 还会在你克隆的文件夹内寻找包。
例如,如果你的 Python 包通常安装在 `~/anaconda3/envs/main/lib/python3.10/Site-packages/`,Python 也会搜索你克隆到的文件夹。`~/diffusers/`。
-
-
-如果你想继续使用这个库,你必须保留 `diffusers` 文件夹。
-
-
+> [!WARNING]
+> 如果你想继续使用这个库,你必须保留 `diffusers` 文件夹。
现在你可以用下面的命令轻松地将你克隆的 🤗 Diffusers 库更新到最新版本。
diff --git a/docs/source/zh/optimization/coreml.md b/docs/source/zh/optimization/coreml.md
index 1d78866720..3926a5ddb0 100644
--- a/docs/source/zh/optimization/coreml.md
+++ b/docs/source/zh/optimization/coreml.md
@@ -13,11 +13,8 @@ http://www.apache.org/licenses/LICENSE-2.0
Core ML 模型可以利用 Apple 设备中所有可用的计算引擎:CPU、GPU 和 Apple Neural Engine(或 ANE,一种在 Apple Silicon Mac 和现代 iPhone/iPad 中可用的张量优化加速器)。根据模型及其运行的设备,Core ML 还可以混合和匹配计算引擎,例如,模型的某些部分可能在 CPU 上运行,而其他部分在 GPU 上运行。
-
-
-您还可以使用 PyTorch 内置的 `mps` 加速器在 Apple Silicon Mac 上运行 `diffusers` Python 代码库。这种方法在 [mps 指南](mps) 中有详细解释,但它与原生应用不兼容。
-
-
+> [!TIP]
+> 您还可以使用 PyTorch 内置的 `mps` 加速器在 Apple Silicon Mac 上运行 `diffusers` Python 代码库。这种方法在 [mps 指南](mps) 中有详细解释,但它与原生应用不兼容。
## Stable Diffusion Core ML 检查点
diff --git a/docs/source/zh/optimization/fp16.md b/docs/source/zh/optimization/fp16.md
index 1088482d24..e1c4c7e57a 100644
--- a/docs/source/zh/optimization/fp16.md
+++ b/docs/source/zh/optimization/fp16.md
@@ -238,11 +238,8 @@ pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph
一般来说,`sigmas`应该[保持在CPU上](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240),以避免通信同步和延迟。
-
-
-参阅[torch.compile和Diffusers:峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何为扩散模型最大化`torch.compile`的性能。
-
-
+> [!TIP]
+> 参阅[torch.compile和Diffusers:峰值性能实践指南](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/)博客文章,了解如何为扩散模型最大化`torch.compile`的性能。
### 基准测试
diff --git a/docs/source/zh/optimization/mps.md b/docs/source/zh/optimization/mps.md
index c76a475336..48b08c5a12 100644
--- a/docs/source/zh/optimization/mps.md
+++ b/docs/source/zh/optimization/mps.md
@@ -35,11 +35,8 @@ image = pipe(prompt).images[0]
image
```
-
-
-PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) 后端不支持大小超过 `2**32` 的 NDArray。如果您遇到此问题,请提交 [Issue](https://github.com/huggingface/diffusers/issues/new/choose) 以便我们调查。
-
-
+> [!WARNING]
+> PyTorch [mps](https://pytorch.org/docs/stable/notes/mps.html) 后端不支持大小超过 `2**32` 的 NDArray。如果您遇到此问题,请提交 [Issue](https://github.com/huggingface/diffusers/issues/new/choose) 以便我们调查。
如果您使用 **PyTorch 1.13**,您需要通过管道进行一次额外的"预热"传递。这是一个临时解决方法,用于解决首次推理传递产生的结果与后续传递略有不同的问题。您只需要执行此传递一次,并且在仅进行一次推理步骤后可以丢弃结果。
diff --git a/docs/source/zh/optimization/neuron.md b/docs/source/zh/optimization/neuron.md
index 709404d56b..99d807a88c 100644
--- a/docs/source/zh/optimization/neuron.md
+++ b/docs/source/zh/optimization/neuron.md
@@ -17,11 +17,8 @@ Diffusers 功能可在 [AWS Inf2 实例](https://aws.amazon.com/ec2/instance-typ
python -m pip install --upgrade-strategy eager optimum[neuronx]
```
-
-
-我们提供预构建的 [Hugging Face Neuron 深度学习 AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2)(DLAMI)和用于 Amazon SageMaker 的 Optimum Neuron 容器。建议正确设置您的环境。
-
-
+> [!TIP]
+> 我们提供预构建的 [Hugging Face Neuron 深度学习 AMI](https://aws.amazon.com/marketplace/pp/prodview-gr3e6yiscria2)(DLAMI)和用于 Amazon SageMaker 的 Optimum Neuron 容器。建议正确设置您的环境。
下面的示例演示了如何在 inf2.8xlarge 实例上使用 Stable Diffusion XL 模型生成图像(一旦模型编译完成,您可以切换到更便宜的 inf2.xlarge 实例)。要生成一些图像,请使用 [`~optimum.neuron.NeuronStableDiffusionXLPipeline`] 类,该类类似于 Diffusers 中的 [`StableDiffusionXLPipeline`] 类。
diff --git a/docs/source/zh/optimization/onnx.md b/docs/source/zh/optimization/onnx.md
index 4b3804d015..b70510d51b 100644
--- a/docs/source/zh/optimization/onnx.md
+++ b/docs/source/zh/optimization/onnx.md
@@ -31,11 +31,8 @@ image = pipeline(prompt).images[0]
pipeline.save_pretrained("./onnx-stable-diffusion-v1-5")
```
-
-
-当前批量生成多个提示可能会占用过高内存。在问题修复前,建议采用迭代方式而非批量处理。
-
-
+> [!WARNING]
+> 当前批量生成多个提示可能会占用过高内存。在问题修复前,建议采用迭代方式而非批量处理。
如需离线导出 ONNX 格式流水线供后续推理使用,请使用 [`optimum-cli export`](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#exporting-a-model-to-onnx-using-the-cli) 命令:
diff --git a/docs/source/zh/optimization/xformers.md b/docs/source/zh/optimization/xformers.md
index 9902feeee6..2a3a3d8341 100644
--- a/docs/source/zh/optimization/xformers.md
+++ b/docs/source/zh/optimization/xformers.md
@@ -17,16 +17,10 @@ http://www.apache.org/licenses/LICENSE-2.0
pip install xformers
```
-
-
-xFormers的`pip`安装包需要最新版本的PyTorch。如需使用旧版PyTorch,建议[从源码安装xFormers](https://github.com/facebookresearch/xformers#installing-xformers)。
-
-
+> [!TIP]
+> xFormers的`pip`安装包需要最新版本的PyTorch。如需使用旧版PyTorch,建议[从源码安装xFormers](https://github.com/facebookresearch/xformers#installing-xformers)。
安装完成后,您可调用`enable_xformers_memory_efficient_attention()`来实现更快的推理速度和更低的内存占用,具体用法参见[此章节](memory#memory-efficient-attention)。
-
-
-根据[此问题](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)反馈,xFormers `v0.0.16`版本在某些GPU上无法用于训练(微调或DreamBooth)。如遇此问题,请按照该issue评论区指引安装开发版本。
-
-
\ No newline at end of file
+> [!WARNING]
+> 根据[此问题](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212)反馈,xFormers `v0.0.16`版本在某些GPU上无法用于训练(微调或DreamBooth)。如遇此问题,请按照该issue评论区指引安装开发版本。
\ No newline at end of file
diff --git a/docs/source/zh/quicktour.md b/docs/source/zh/quicktour.md
index 08efaa87d2..2b8803384f 100644
--- a/docs/source/zh/quicktour.md
+++ b/docs/source/zh/quicktour.md
@@ -31,11 +31,8 @@ specific language governing permissions and limitations under the License.
快速入门将告诉你如何使用[`DiffusionPipeline`]进行推理,然后指导你如何结合模型和调度器以复现[`DiffusionPipeline`]内部发生的事情。
-
-
-快速入门是🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)的简化版,可以帮助你快速上手。如果你想了解更多关于🧨 Diffusers的目标、设计理念以及关于它的核心API的更多细节,可以点击🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)查看。
-
-
+> [!TIP]
+> 快速入门是🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)的简化版,可以帮助你快速上手。如果你想了解更多关于🧨 Diffusers的目标、设计理念以及关于它的核心API的更多细节,可以点击🧨[Diffusers入门](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/diffusers_intro.ipynb)查看。
在开始之前,确认一下你已经安装好了所需要的库:
@@ -66,11 +63,10 @@ pip install --upgrade diffusers accelerate transformers
您可以在Hugging Face Hub上使用[DiffusionPipeline]的任何检查点。
在本快速入门中,您将加载stable-diffusion-v1-5检查点,用于文本到图像生成。
-。
-
-对于[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion)模型,在运行该模型之前,请先仔细阅读[许可证](https://huggingface.co/spaces/CompVis/stable-diffusion-license)。🧨 Diffusers实现了一个[`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py),以防止有攻击性的或有害的内容,但Stable Diffusion模型改进图像的生成能力仍有可能产生潜在的有害内容。
-
-
+> [!WARNING]
+> 。
+>
+> 对于[Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion)模型,在运行该模型之前,请先仔细阅读[许可证](https://huggingface.co/spaces/CompVis/stable-diffusion-license)。🧨 Diffusers实现了一个[`safety_checker`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py),以防止有攻击性的或有害的内容,但Stable Diffusion模型改进图像的生成能力仍有可能产生潜在的有害内容。
用[`~DiffusionPipeline.from_pretrained`]方法加载模型。
@@ -221,11 +217,8 @@ torch.Size([1, 3, 256, 256])
-
-
-🧨 Diffusers是一个用于构建扩散系统的工具箱。预定义好的扩散系统[`DiffusionPipeline`]能方便你快速试用,你也可以单独选择自己的模型和调度器组件来建立一个自定义的扩散系统。
-
-
+> [!TIP]
+> 🧨 Diffusers是一个用于构建扩散系统的工具箱。预定义好的扩散系统[`DiffusionPipeline`]能方便你快速试用,你也可以单独选择自己的模型和调度器组件来建立一个自定义的扩散系统。
在快速入门教程中,你将用它的[`~diffusers.ConfigMixin.from_config`]方法实例化[`DDPMScheduler`]:
@@ -249,12 +242,8 @@ DDPMScheduler {
}
```
-
-
-
-💡 注意调度器是如何从配置中实例化的。与模型不同,调度器没有可训练的权重,而且是无参数的。
-
-
+> [!TIP]
+> 💡 注意调度器是如何从配置中实例化的。与模型不同,调度器没有可训练的权重,而且是无参数的。
* `num_train_timesteps`:去噪过程的长度,或者换句话说,将随机高斯噪声处理成数据样本所需的时间步数。
* `beta_schedule`:用于推理和训练的噪声表。
diff --git a/docs/source/zh/stable_diffusion.md b/docs/source/zh/stable_diffusion.md
index bf9288c5b7..d337fb41a0 100644
--- a/docs/source/zh/stable_diffusion.md
+++ b/docs/source/zh/stable_diffusion.md
@@ -1,264 +1,258 @@
-
-
-# 有效且高效的扩散
-
-[[open-in-colab]]
-
-让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
-
-这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
-
-
-本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
-
-
-首先,加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 模型:
-
-```python
-from diffusers import DiffusionPipeline
-
-model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
-pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
-```
-
-本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
-
-```python
-prompt = "portrait photo of a old warrior chief"
-```
-
-## 速度
-
-
-
-💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
-
-
-
-加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
-
-```python
-pipeline = pipeline.to("cuda")
-```
-
-为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reusing_seeds):
-
-```python
-import torch
-
-generator = torch.Generator("cuda").manual_seed(0)
-```
-
-现在,你可以生成一个图像:
-
-```python
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-

-
-
-在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
-
-
-让我们把模型的精度降低至 `float16` ,然后生成一张图像:
-
-```python
-import torch
-
-pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
-pipeline = pipeline.to("cuda")
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator).images[0]
-image
-```
-
-
-

-
-
-这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
-
-
-
-💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
-
-
-
-另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
-
-```python
-pipeline.scheduler.compatibles
-[
- diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
- diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
- diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
- diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
- diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
- diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
- diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
- diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
- diffusers.schedulers.scheduling_pndm.PNDMScheduler,
- diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
- diffusers.schedulers.scheduling_ddim.DDIMScheduler,
-]
-```
-
-Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
-
-```python
-from diffusers import DPMSolverMultistepScheduler
-
-pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
-```
-
-现在将 `num_inference_steps` 设置为 20:
-
-```python
-generator = torch.Generator("cuda").manual_seed(0)
-image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
-image
-```
-
-
-

-
-
-太棒了!你成功把推理时间缩短到 4 秒!⚡️
-
-## 内存
-
-改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
-
-创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
-
-
-```python
-def get_inputs(batch_size=1):
- generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
- prompts = batch_size * [prompt]
- num_inference_steps = 20
-
- return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
-```
-
-设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
-
-```python
-from diffusers.utils import make_image_grid
-
-images = pipeline(**get_inputs(batch_size=4)).images
-make_image_grid(images, 2, 2)
-```
-
-除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
-
-```python
-pipeline.enable_attention_slicing()
-```
-
-现在尝试把 `batch_size` 增加到 8!
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
-
-## 质量
-
-在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
-
-### 更好的 checkpoints
-
-有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
-
-随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
-
-### 更好的 pipeline 组件
-
-也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
-
-```python
-from diffusers import AutoencoderKL
-
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
-pipeline.vae = vae
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-### 更好的提示词工程
-
-用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
-
-- 我想生成的图像或类似图像如何存储在互联网上?
-- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
-
-考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
-
-```python
-prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
-prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
-```
-
-使用新的提示词生成一批图像:
-
-```python
-images = pipeline(**get_inputs(batch_size=8)).images
-make_image_grid(images, rows=2, cols=4)
-```
-
-
-

-
-
-非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
-
-```python
-prompts = [
- "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
- "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
-]
-
-generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
-images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
-make_image_grid(images, 2, 2)
-```
-
-
-

-
-
-## 最后
-
-在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
-
-- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
-- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
-- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
+
+
+# 有效且高效的扩散
+
+[[open-in-colab]]
+
+让 [`DiffusionPipeline`] 生成特定风格或包含你所想要的内容的图像可能会有些棘手。 通常情况下,你需要多次运行 [`DiffusionPipeline`] 才能得到满意的图像。但是从无到有生成图像是一个计算密集的过程,特别是如果你要一遍又一遍地进行推理运算。
+
+这就是为什么从pipeline中获得最高的 *computational* (speed) 和 *memory* (GPU RAM) 非常重要 ,以减少推理周期之间的时间,从而使迭代速度更快。
+
+
+本教程将指导您如何通过 [`DiffusionPipeline`] 更快、更好地生成图像。
+
+
+首先,加载 [`stable-diffusion-v1-5/stable-diffusion-v1-5`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) 模型:
+
+```python
+from diffusers import DiffusionPipeline
+
+model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+pipeline = DiffusionPipeline.from_pretrained(model_id, use_safetensors=True)
+```
+
+本教程将使用的提示词是 [`portrait photo of a old warrior chief`] ,但是你可以随心所欲的想象和构造自己的提示词:
+
+```python
+prompt = "portrait photo of a old warrior chief"
+```
+
+## 速度
+
+> [!TIP]
+> 💡 如果你没有 GPU, 你可以从像 [Colab](https://colab.research.google.com/) 这样的 GPU 提供商获取免费的 GPU !
+
+加速推理的最简单方法之一是将 pipeline 放在 GPU 上 ,就像使用任何 PyTorch 模块一样:
+
+```python
+pipeline = pipeline.to("cuda")
+```
+
+为了确保您可以使用相同的图像并对其进行改进,使用 [`Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) 方法,然后设置一个随机数种子 以确保其 [复现性](./using-diffusers/reusing_seeds):
+
+```python
+import torch
+
+generator = torch.Generator("cuda").manual_seed(0)
+```
+
+现在,你可以生成一个图像:
+
+```python
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+

+
+
+在 T4 GPU 上,这个过程大概要30秒(如果你的 GPU 比 T4 好,可能会更快)。在默认情况下,[`DiffusionPipeline`] 使用完整的 `float32` 精度进行 50 步推理。你可以通过降低精度(如 `float16` )或者减少推理步数来加速整个过程
+
+
+让我们把模型的精度降低至 `float16` ,然后生成一张图像:
+
+```python
+import torch
+
+pipeline = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, use_safetensors=True)
+pipeline = pipeline.to("cuda")
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
+

+
+
+这一次,生成图像只花了约 11 秒,比之前快了近 3 倍!
+
+> [!TIP]
+> 💡 我们强烈建议把 pipeline 精度降低至 `float16` , 到目前为止, 我们很少看到输出质量有任何下降。
+
+另一个选择是减少推理步数。 你可以选择一个更高效的调度器 (*scheduler*) 可以减少推理步数同时保证输出质量。您可以在 [DiffusionPipeline] 中通过调用compatibles方法找到与当前模型兼容的调度器 (*scheduler*)。
+
+```python
+pipeline.scheduler.compatibles
+[
+ diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler,
+ diffusers.schedulers.scheduling_unipc_multistep.UniPCMultistepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_discrete.KDPM2DiscreteScheduler,
+ diffusers.schedulers.scheduling_deis_multistep.DEISMultistepScheduler,
+ diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler,
+ diffusers.schedulers.scheduling_ddpm.DDPMScheduler,
+ diffusers.schedulers.scheduling_dpmsolver_singlestep.DPMSolverSinglestepScheduler,
+ diffusers.schedulers.scheduling_k_dpm_2_ancestral_discrete.KDPM2AncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_heun_discrete.HeunDiscreteScheduler,
+ diffusers.schedulers.scheduling_pndm.PNDMScheduler,
+ diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler,
+ diffusers.schedulers.scheduling_ddim.DDIMScheduler,
+]
+```
+
+Stable Diffusion 模型默认使用的是 [`PNDMScheduler`] ,通常要大概50步推理, 但是像 [`DPMSolverMultistepScheduler`] 这样更高效的调度器只要大概 20 或 25 步推理. 使用 [`ConfigMixin.from_config`] 方法加载新的调度器:
+
+```python
+from diffusers import DPMSolverMultistepScheduler
+
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+```
+
+现在将 `num_inference_steps` 设置为 20:
+
+```python
+generator = torch.Generator("cuda").manual_seed(0)
+image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0]
+image
+```
+
+
+

+
+
+太棒了!你成功把推理时间缩短到 4 秒!⚡️
+
+## 内存
+
+改善 pipeline 性能的另一个关键是减少内存的使用量,这间接意味着速度更快,因为你经常试图最大化每秒生成的图像数量。要想知道你一次可以生成多少张图片,最简单的方法是尝试不同的batch size,直到出现`OutOfMemoryError` (OOM)。
+
+创建一个函数,为每一批要生成的图像分配提示词和 `Generators` 。请务必为每个`Generator` 分配一个种子,以便于复现良好的结果。
+
+
+```python
+def get_inputs(batch_size=1):
+ generator = [torch.Generator("cuda").manual_seed(i) for i in range(batch_size)]
+ prompts = batch_size * [prompt]
+ num_inference_steps = 20
+
+ return {"prompt": prompts, "generator": generator, "num_inference_steps": num_inference_steps}
+```
+
+设置 `batch_size=4` ,然后看一看我们消耗了多少内存:
+
+```python
+from diffusers.utils import make_image_grid
+
+images = pipeline(**get_inputs(batch_size=4)).images
+make_image_grid(images, 2, 2)
+```
+
+除非你有一个更大内存的GPU, 否则上述代码会返回 `OOM` 错误! 大部分内存被 cross-attention 层使用。按顺序运行可以节省大量内存,而不是在批处理中进行。你可以为 pipeline 配置 [`~DiffusionPipeline.enable_attention_slicing`] 函数:
+
+```python
+pipeline.enable_attention_slicing()
+```
+
+现在尝试把 `batch_size` 增加到 8!
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+以前你不能一批生成 4 张图片,而现在你可以在一张图片里面生成八张图片而只需要大概3.5秒!这可能是 T4 GPU 在不牺牲质量的情况运行速度最快的一种方法。
+
+## 质量
+
+在最后两节中, 你要学习如何通过 `fp16` 来优化 pipeline 的速度, 通过使用性能更高的调度器来减少推理步数, 使用注意力切片(*enabling attention slicing*)方法来节省内存。现在,你将关注的是如何提高图像的质量。
+
+### 更好的 checkpoints
+
+有个显而易见的方法是使用更好的 checkpoints。 Stable Diffusion 模型是一个很好的起点, 自正式发布以来,还发布了几个改进版本。然而, 使用更新的版本并不意味着你会得到更好的结果。你仍然需要尝试不同的 checkpoints ,并做一些研究 (例如使用 [negative prompts](https://minimaxir.com/2022/11/stable-diffusion-negative-prompt/)) 来获得更好的结果。
+
+随着该领域的发展, 有越来越多经过微调的高质量的 checkpoints 用来生成不一样的风格. 在 [Hub](https://huggingface.co/models?library=diffusers&sort=downloads) 和 [Diffusers Gallery](https://huggingface.co/spaces/huggingface-projects/diffusers-gallery) 寻找你感兴趣的一种!
+
+### 更好的 pipeline 组件
+
+也可以尝试用新版本替换当前 pipeline 组件。让我们加载最新的 [autodecoder](https://huggingface.co/stabilityai/stable-diffusion-2-1/tree/main/vae) 从 Stability AI 加载到 pipeline, 并生成一些图像:
+
+```python
+from diffusers import AutoencoderKL
+
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
+pipeline.vae = vae
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+### 更好的提示词工程
+
+用于生成图像的文本非常重要, 因此被称为 *提示词工程*。 在设计提示词工程应注意如下事项:
+
+- 我想生成的图像或类似图像如何存储在互联网上?
+- 我可以提供哪些额外的细节来引导模型朝着我想要的风格生成?
+
+考虑到这一点,让我们改进提示词,以包含颜色和更高质量的细节:
+
+```python
+prompt += ", tribal panther make up, blue on red, side profile, looking away, serious eyes"
+prompt += " 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta"
+```
+
+使用新的提示词生成一批图像:
+
+```python
+images = pipeline(**get_inputs(batch_size=8)).images
+make_image_grid(images, rows=2, cols=4)
+```
+
+
+

+
+
+非常的令人印象深刻! Let's tweak the second image - 把 `Generator` 的种子设置为 `1` - 添加一些关于年龄的主题文本:
+
+```python
+prompts = [
+ "portrait photo of the oldest warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+ "portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta",
+]
+
+generator = [torch.Generator("cuda").manual_seed(1) for _ in range(len(prompts))]
+images = pipeline(prompt=prompts, generator=generator, num_inference_steps=25).images
+make_image_grid(images, 2, 2)
+```
+
+
+

+
+
+## 最后
+
+在本教程中, 您学习了如何优化[`DiffusionPipeline`]以提高计算和内存效率,以及提高生成输出的质量. 如果你有兴趣让你的 pipeline 更快, 可以看一看以下资源:
+
+- 学习 [PyTorch 2.0](./optimization/torch2.0) 和 [`torch.compile`](https://pytorch.org/docs/stable/generated/torch.compile.html) 可以让推理速度提高 5 - 300% . 在 A100 GPU 上, 推理速度可以提高 50% !
+- 如果你没法用 PyTorch 2, 我们建议你安装 [xFormers](./optimization/xformers)。它的内存高效注意力机制(*memory-efficient attention mechanism*)与PyTorch 1.13.1配合使用,速度更快,内存消耗更少。
+- 其他的优化技术, 如:模型卸载(*model offloading*), 包含在 [这份指南](./optimization/fp16).
diff --git a/docs/source/zh/training/adapt_a_model.md b/docs/source/zh/training/adapt_a_model.md
index b5f9155697..7dbf46ec12 100644
--- a/docs/source/zh/training/adapt_a_model.md
+++ b/docs/source/zh/training/adapt_a_model.md
@@ -16,12 +16,12 @@ pipeline.unet.config["in_channels"]
4
```
-而图像修复任务需要输入样本具有9个通道。您可以在 [`runwayml/stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数:
+而图像修复任务需要输入样本具有9个通道。您可以在 [`stable-diffusion-v1-5/stable-diffusion-inpainting`](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting) 这样的预训练修复模型中验证此参数:
```python
from diffusers import StableDiffusionPipeline
-pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", use_safetensors=True)
+pipeline = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-inpainting", use_safetensors=True)
pipeline.unet.config["in_channels"]
9
```
diff --git a/docs/source/zh/training/controlnet.md b/docs/source/zh/training/controlnet.md
index e943177ced..84bc3263a8 100644
--- a/docs/source/zh/training/controlnet.md
+++ b/docs/source/zh/training/controlnet.md
@@ -68,11 +68,8 @@ pip install -r requirements_flax.txt
-
-
-🤗 Accelerate 是一个支持多GPU/TPU训练和混合精度的库,它能根据硬件环境自动配置训练方案。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个支持多GPU/TPU训练和混合精度的库,它能根据硬件环境自动配置训练方案。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
初始化🤗 Accelerate环境:
@@ -96,11 +93,8 @@ write_basic_config()
最后,如需训练自定义数据集,请参阅 [创建训练数据集](create_dataset) 指南了解数据准备方法。
-
-
-下文重点解析脚本中的关键模块,但不会覆盖所有实现细节。如需深入了解,建议直接阅读 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py),如有疑问欢迎反馈。
-
-
+> [!TIP]
+> 下文重点解析脚本中的关键模块,但不会覆盖所有实现细节。如需深入了解,建议直接阅读 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py),如有疑问欢迎反馈。
## 脚本参数
@@ -135,11 +129,8 @@ accelerate launch train_controlnet.py \
脚本中的 [`make_train_dataset`](https://github.com/huggingface/diffusers/blob/64603389da01082055a901f2883c4810d1144edb/examples/controlnet/train_controlnet.py#L582) 函数负责数据预处理,除常规的文本标注分词和图像变换外,还包含条件图像的特效处理:
-
-
-在TPU上流式加载数据集时,🤗 Datasets库可能成为性能瓶颈(因其未针对图像数据优化)。建议考虑 [WebDataset](https://webdataset.github.io/webdataset/)、[TorchData](https://github.com/pytorch/data) 或 [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) 等高效数据格式。
-
-
+> [!TIP]
+> 在TPU上流式加载数据集时,🤗 Datasets库可能成为性能瓶颈(因其未针对图像数据优化)。建议考虑 [WebDataset](https://webdataset.github.io/webdataset/)、[TorchData](https://github.com/pytorch/data) 或 [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds) 等高效数据格式。
```py
conditioning_image_transforms = transforms.Compose(
@@ -304,11 +295,8 @@ tensorboard --logdir runs/fill-circle-100steps-20230411_165612/
在 [http://localhost:6006/#profile](http://localhost:6006/#profile) 查看分析结果。
-
-
-若遇到插件版本冲突,建议重新安装TensorFlow和Tensorboard。注意性能分析插件仍处实验阶段,部分视图可能不完整。`trace_viewer` 会截断超过1M的事件记录,在编译步骤分析时可能导致设备轨迹丢失。
-
-
+> [!WARNING]
+> 若遇到插件版本冲突,建议重新安装TensorFlow和Tensorboard。注意性能分析插件仍处实验阶段,部分视图可能不完整。`trace_viewer` 会截断超过1M的事件记录,在编译步骤分析时可能导致设备轨迹丢失。
```bash
python3 train_controlnet_flax.py \
diff --git a/docs/source/zh/training/distributed_inference.md b/docs/source/zh/training/distributed_inference.md
index e0537735b2..60297371d6 100644
--- a/docs/source/zh/training/distributed_inference.md
+++ b/docs/source/zh/training/distributed_inference.md
@@ -43,11 +43,8 @@ with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
accelerate launch run_distributed.py --num_processes=2
```
-
-
-参考这个最小示例 [脚本](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) 以在多个 GPU 上运行推理。要了解更多信息,请查看 [使用 🤗 Accelerate 进行分布式推理](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 指南。
-
-
+> [!TIP]
+> 参考这个最小示例 [脚本](https://gist.github.com/sayakpaul/cfaebd221820d7b43fae638b4dfa01ba) 以在多个 GPU 上运行推理。要了解更多信息,请查看 [使用 🤗 Accelerate 进行分布式推理](https://huggingface.co/docs/accelerate/en/usage_guides/distributed_inference#distributed-inference-with-accelerate) 指南。
## PyTorch Distributed
diff --git a/docs/source/zh/training/dreambooth.md b/docs/source/zh/training/dreambooth.md
index 493c5385ff..cae5e30be0 100644
--- a/docs/source/zh/training/dreambooth.md
+++ b/docs/source/zh/training/dreambooth.md
@@ -44,11 +44,8 @@ pip install -r requirements_flax.txt
-
-
-🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
初始化 🤗 Accelerate 环境:
@@ -73,19 +70,13 @@ write_basic_config()
最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与
训练脚本。
-
-
-以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读[脚本](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py),并告诉我们如果您有任何问题或疑虑。
-
-
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读[脚本](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py),并告诉我们如果您有任何问题或疑虑。
## 脚本参数
-
-
-DreamBooth 对训练超参数非常敏感,容易过拟合。阅读 [使用 🧨 Diffusers 训练 Stable Diffusion 与 Dreambooth](https://huggingface.co/blog/dreambooth) 博客文章,了解针对不同主题的推荐设置,以帮助您选择合适的超参数。
-
-
+> [!WARNING]
+> DreamBooth 对训练超参数非常敏感,容易过拟合。阅读 [使用 🧨 Diffusers 训练 Stable Diffusion 与 Dreambooth](https://huggingface.co/blog/dreambooth) 博客文章,了解针对不同主题的推荐设置,以帮助您选择合适的超参数。
训练脚本提供了许多参数来自定义您的训练运行。所有参数及其描述都可以在 [`parse_args()`](https://github.com/huggingface/diffusers/blob/072e00897a7cf4302c347a63ec917b4b8add16d4/examples/dreambooth/train_dreambooth.py#L228) 函数中找到。参数设置了默认值,这些默认值应该开箱即用效果不错,但如果您愿意,也可以在训练命令中设置自己的值。
@@ -359,29 +350,26 @@ python train_dreambooth_flax.py \
训练完成后,您可以使用新训练的模型进行推理!
-
-
-等不及在训练完成前就尝试您的模型进行推理?🤭 请确保安装了最新版本的 🤗 Accelerate。
-
-```py
-from diffusers import DiffusionPipeline, UNet2DConditionModel
-from transformers import CLIPTextModel
-import torch
-
-unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
-
-# 如果您使用了 `--args.train_text_encoder` 进行训练,请确保也加载文本编码器
-text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/checkpoint-100/text_encoder")
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
-).to("cuda")
-
-image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
-image.save("dog-bucket.png")
-```
-
-
+> [!TIP]
+> 等不及在训练完成前就尝试您的模型进行推理?🤭 请确保安装了最新版本的 🤗 Accelerate。
+>
+> ```py
+> from diffusers import DiffusionPipeline, UNet2DConditionModel
+> from transformers import CLIPTextModel
+> import torch
+>
+> unet = UNet2DConditionModel.from_pretrained("path/to/model/checkpoint-100/unet")
+>
+> # 如果您使用了 `--args.train_text_encoder` 进行训练,请确保也加载文本编码器
+> text_encoder = CLIPTextModel.from_pretrained("path/to/model/checkpoint-100/checkpoint-100/text_encoder")
+>
+> pipeline = DiffusionPipeline.from_pretrained(
+> "stable-diffusion-v1-5/stable-diffusion-v1-5", unet=unet, text_encoder=text_encoder, dtype=torch.float16,
+> ).to("cuda")
+>
+> image = pipeline("A photo of sks dog in a bucket", num_inference_steps=50, guidance_scale=7.5).images[0]
+> image.save("dog-bucket.png")
+> ```
diff --git a/docs/source/zh/training/instructpix2pix.md b/docs/source/zh/training/instructpix2pix.md
index b1b616366a..1f9f4eb21e 100644
--- a/docs/source/zh/training/instructpix2pix.md
+++ b/docs/source/zh/training/instructpix2pix.md
@@ -31,11 +31,8 @@ cd examples/instruct_pix2pix
pip install -r requirements.txt
```
-
-
-🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它将根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速导览](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个库,用于帮助您在多个 GPU/TPU 上或使用混合精度进行训练。它将根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速导览](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
初始化一个 🤗 Accelerate 环境:
@@ -59,11 +56,8 @@ write_basic_config()
最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
-
-
-以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py),并告诉我们如果您有任何问题或疑虑。
-
-
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py),并告诉我们如果您有任何问题或疑虑。
## 脚本参数
@@ -176,15 +170,12 @@ if args.conditioning_dropout_prob is not None:
将 `MODEL_NAME` 环境变量设置为模型名称(可以是 Hub 上的模型 ID 或本地模型的路径),并将 `DATASET_ID` 设置为 Hub 上数据集的名称。脚本会创建并保存所有组件(特征提取器、调度器、文本编码器、UNet 等)到您的仓库中的一个子文件夹。
-
-
-为了获得更好的结果,尝试使用更大的数据集进行更长时间的训练。我们只在较小规模的数据集上测试过此训练脚本。
-
-
-
-要使用 Weights and Biases 监控训练进度,请将 `--report_to=wandb` 参数添加到训练命令中,并使用 `--val_image_url` 指定验证图像,使用 `--validation_prompt` 指定验证提示。这对于调试模型非常有用。
-
-
+> [!TIP]
+> 为了获得更好的结果,尝试使用更大的数据集进行更长时间的训练。我们只在较小规模的数据集上测试过此训练脚本。
+>
+>
+>
+> 要使用 Weights and Biases 监控训练进度,请将 `--report_to=wandb` 参数添加到训练命令中,并使用 `--val_image_url` 指定验证图像,使用 `--validation_prompt` 指定验证提示。这对于调试模型非常有用。
如果您在多个 GPU 上训练,请将 `--multi_gpu` 参数添加到 `accelerate launch` 命令中。
diff --git a/docs/source/zh/training/kandinsky.md b/docs/source/zh/training/kandinsky.md
index 8da5c0c3a0..8ef3524ee7 100644
--- a/docs/source/zh/training/kandinsky.md
+++ b/docs/source/zh/training/kandinsky.md
@@ -9,11 +9,8 @@ http://www.apache.org/licenses/LICENSE-2.0
# Kandinsky 2.2
-
-
-此脚本是实验性的,容易过拟合并遇到灾难性遗忘等问题。尝试探索不同的超参数以在您的数据集上获得最佳结果。
-
-
+> [!WARNING]
+> 此脚本是实验性的,容易过拟合并遇到灾难性遗忘等问题。尝试探索不同的超参数以在您的数据集上获得最佳结果。
Kandinsky 2.2 是一个多语言文本到图像模型,能够生成更逼真的图像。该模型包括一个图像先验模型,用于从文本提示创建图像嵌入,以及一个解码器模型,基于先验模型的嵌入生成图像。这就是为什么在 Diffusers 中您会找到两个独立的脚本用于 Kandinsky 2.2,一个用于训练先验模型,另一个用于训练解码器模型。您可以分别训练这两个模型,但为了获得最佳结果,您应该同时训练先验和解码器模型。
@@ -36,12 +33,9 @@ cd examples/kandinsky2_2/text_to_image
pip install -r requirements.txt
```
-
-
-🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate 的 [快速入门](https://huggingface.co/docs/accelerate/quicktour
-) 了解更多。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate 的 [快速入门](https://huggingface.co/docs/accelerate/quicktour
+> ) 了解更多。
初始化一个 🤗 Accelerate 环境:
@@ -65,11 +59,8 @@ write_basic_config()
最后,如果您想在自己的数据集上训练模型,请查看 [创建用于训练的数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
-
-
-以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读脚本,并让我们知道您有任何疑问或顾虑。
-
-
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未详细涵盖脚本的每个方面。如果您有兴趣了解更多,请随时阅读脚本,并让我们知道您有任何疑问或顾虑。
## 脚本参数
@@ -209,12 +200,9 @@ model_pred = unet(noisy_latents, timesteps, None, added_cond_kwargs=added_cond_k
如果您在多个GPU上训练,请在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。
-
-
-要使用Weights & Biases监控训练进度,请在训练命令中添加 `--report_to=wandb` 参数。您还需要
-建议在训练命令中添加 `--validation_prompt` 以跟踪结果。这对于调试模型和查看中间结果非常有用。
-
-
+> [!TIP]
+> 要使用Weights & Biases监控训练进度,请在训练命令中添加 `--report_to=wandb` 参数。您还需要
+> 建议在训练命令中添加 `--validation_prompt` 以跟踪结果。这对于调试模型和查看中间结果非常有用。
@@ -284,11 +272,8 @@ prompt="A robot naruto, 4k photo"
image = pipeline(prompt=prompt, negative_prompt=negative_prompt).images[0]
```
-
-
-可以随意将 `kandinsky-community/kandinsky-2-2-decoder` 替换为您自己训练的 decoder 检查点!
-
-
+> [!TIP]
+> 可以随意将 `kandinsky-community/kandinsky-2-2-decoder` 替换为您自己训练的 decoder 检查点!
diff --git a/docs/source/zh/training/lora.md b/docs/source/zh/training/lora.md
index a7b7abb32d..ce29365450 100644
--- a/docs/source/zh/training/lora.md
+++ b/docs/source/zh/training/lora.md
@@ -12,19 +12,13 @@ specific language governing permissions and limitations under the License.
# LoRA 低秩适配
-
-
-当前功能处于实验阶段,API可能在未来版本中变更。
-
-
+> [!WARNING]
+> 当前功能处于实验阶段,API可能在未来版本中变更。
[LoRA(大语言模型的低秩适配)](https://hf.co/papers/2106.09685) 是一种轻量级训练技术,能显著减少可训练参数量。其原理是通过向模型注入少量新权重参数,仅训练这些新增参数。这使得LoRA训练速度更快、内存效率更高,并生成更小的模型权重文件(通常仅数百MB),便于存储和分享。LoRA还可与DreamBooth等其他训练技术结合以加速训练过程。
-
-
-LoRA具有高度通用性,目前已支持以下应用场景:[DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)、[Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py)、[Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py)、[文生图](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)以及[Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py)。
-
-
+> [!TIP]
+> LoRA具有高度通用性,目前已支持以下应用场景:[DreamBooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)、[Kandinsky 2.2](https://github.com/huggingface/diffusers/blob/main/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py)、[Stable Diffusion XL](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py)、[文生图](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)以及[Wuerstchen](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py)。
本指南将通过解析[train_text_to_image_lora.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)脚本,帮助您深入理解其工作原理,并掌握如何针对具体需求进行定制化修改。
@@ -57,11 +51,8 @@ pip install -r requirements_flax.txt
-
-
-🤗 Accelerate是一个支持多GPU/TPU训练和混合精度计算的库,它能根据硬件环境自动配置训练方案。参阅🤗 Accelerate[快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
-
-
+> [!TIP]
+> 🤗 Accelerate是一个支持多GPU/TPU训练和混合精度计算的库,它能根据硬件环境自动配置训练方案。参阅🤗 Accelerate[快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
初始化🤗 Accelerate环境:
@@ -85,11 +76,8 @@ write_basic_config()
如需训练自定义数据集,请参考[创建训练数据集指南](create_dataset)了解数据准备流程。
-
-
-以下章节重点解析训练脚本中与LoRA相关的核心部分,但不会涵盖所有实现细节。如需完整理解,建议直接阅读[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py),如有疑问欢迎反馈。
-
-
+> [!TIP]
+> 以下章节重点解析训练脚本中与LoRA相关的核心部分,但不会涵盖所有实现细节。如需完整理解,建议直接阅读[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py),如有疑问欢迎反馈。
## 脚本参数
@@ -177,11 +165,8 @@ optimizer = optimizer_cls(
多GPU训练请添加`--multi_gpu`参数。
-
-
-在11GB显存的2080 Ti显卡上完整训练约需5小时。
-
-
+> [!WARNING]
+> 在11GB显存的2080 Ti显卡上完整训练约需5小时。
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
diff --git a/docs/source/zh/training/text2image.md b/docs/source/zh/training/text2image.md
index 193b839e9b..4465adbe2a 100644
--- a/docs/source/zh/training/text2image.md
+++ b/docs/source/zh/training/text2image.md
@@ -12,11 +12,8 @@ specific language governing permissions and limitations under the License.
# 文生图
-
-
-文生图训练脚本目前处于实验阶段,容易出现过拟合和灾难性遗忘等问题。建议尝试不同超参数以获得最佳数据集适配效果。
-
-
+> [!WARNING]
+> 文生图训练脚本目前处于实验阶段,容易出现过拟合和灾难性遗忘等问题。建议尝试不同超参数以获得最佳数据集适配效果。
Stable Diffusion 等文生图模型能够根据文本提示生成对应图像。
@@ -49,11 +46,8 @@ pip install -r requirements_flax.txt
-
-
-🤗 Accelerate 是支持多GPU/TPU训练和混合精度的工具库,能根据硬件环境自动配置训练参数。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
-
-
+> [!TIP]
+> 🤗 Accelerate 是支持多GPU/TPU训练和混合精度的工具库,能根据硬件环境自动配置训练参数。参阅 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 了解更多。
初始化 🤗 Accelerate 环境:
@@ -79,11 +73,8 @@ write_basic_config()
## 脚本参数
-
-
-以下重点介绍脚本中影响训练效果的关键参数,如需完整参数说明可查阅 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)。如有疑问欢迎反馈。
-
-
+> [!TIP]
+> 以下重点介绍脚本中影响训练效果的关键参数,如需完整参数说明可查阅 [脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)。如有疑问欢迎反馈。
训练脚本提供丰富参数供自定义训练流程,所有参数及说明详见 [`parse_args()`](https://github.com/huggingface/diffusers/blob/8959c5b9dec1c94d6ba482c94a58d2215c5fd026/examples/text_to_image/train_text_to_image.py#L193) 函数。该函数为每个参数提供默认值(如批次大小、学习率等),也可通过命令行参数覆盖。
@@ -160,11 +151,8 @@ def preprocess_train(examples):
以 [火影忍者BLIP标注数据集](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions) 为例训练生成火影角色。设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集(Hub或本地路径)。多GPU训练需在 `accelerate launch` 命令中添加 `--multi_gpu` 参数。
-
-
-使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
-
-
+> [!TIP]
+> 使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
@@ -194,11 +182,8 @@ Flax训练方案在TPU/GPU上效率更高(由 [@duongna211](https://github.com
设置环境变量 `MODEL_NAME` 和 `dataset_name` 指定模型和数据集(Hub或本地路径)。
-
-
-使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
-
-
+> [!TIP]
+> 使用本地数据集时,设置 `TRAIN_DIR` 和 `OUTPUT_DIR` 环境变量为数据集路径和模型保存路径。
```bash
export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
diff --git a/docs/source/zh/training/text_inversion.md b/docs/source/zh/training/text_inversion.md
index 2945699c61..eda9f91144 100644
--- a/docs/source/zh/training/text_inversion.md
+++ b/docs/source/zh/training/text_inversion.md
@@ -45,11 +45,8 @@ pip install -r requirements_flax.txt
-
-
-🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多GPU/TPU或混合精度环境下训练的工具库。它会根据硬件和环境自动配置训练设置。查看🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour)了解更多。
初始化🤗 Accelerate环境:
@@ -73,11 +70,8 @@ write_basic_config()
最后,如果想在自定义数据集上训练模型,请参阅[创建训练数据集](create_dataset)指南,了解如何创建适用于训练脚本的数据集。
-
-
-以下部分重点介绍训练脚本中需要理解的关键修改点,但未涵盖脚本所有细节。如需深入了解,可随时查阅[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py),如有疑问欢迎反馈。
-
-
+> [!TIP]
+> 以下部分重点介绍训练脚本中需要理解的关键修改点,但未涵盖脚本所有细节。如需深入了解,可随时查阅[脚本源码](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py),如有疑问欢迎反馈。
## 脚本参数
@@ -173,11 +167,8 @@ snapshot_download(
- `token_identifier.txt`:特殊占位符词汇
- `type_of_concept.txt`:训练概念类型("object"或"style")
-
-
-在单块V100 GPU上完整训练约需1小时。
-
-
+> [!WARNING]
+> 在单块V100 GPU上完整训练约需1小时。
启动脚本前还有最后一步。如果想实时观察训练过程,可以定期保存生成图像。在训练命令中添加以下参数:
diff --git a/docs/source/zh/training/wuerstchen.md b/docs/source/zh/training/wuerstchen.md
index 8a6abe6624..c80cc944a3 100644
--- a/docs/source/zh/training/wuerstchen.md
+++ b/docs/source/zh/training/wuerstchen.md
@@ -33,11 +33,8 @@ cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```
-
-
-🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
-
-
+> [!TIP]
+> 🤗 Accelerate 是一个帮助您在多个 GPU/TPU 上或使用混合精度进行训练的库。它会根据您的硬件和环境自动配置训练设置。查看 🤗 Accelerate [快速入门](https://huggingface.co/docs/accelerate/quicktour) 以了解更多信息。
初始化一个 🤗 Accelerate 环境:
@@ -61,11 +58,8 @@ write_basic_config()
最后,如果您想在自己的数据集上训练模型,请查看 [创建训练数据集](create_dataset) 指南,了解如何创建与训练脚本兼容的数据集。
-
-
-以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未涵盖 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 的详细信息。如果您有兴趣了解更多,请随时阅读脚本,并告诉我们您是否有任何问题或疑虑。
-
-
+> [!TIP]
+> 以下部分重点介绍了训练脚本中对于理解如何修改它很重要的部分,但并未涵盖 [脚本](https://github.com/huggingface/diffusers/blob/main/examples/wuerstchen/text_to_image/train_text_to_image_prior.py) 的详细信息。如果您有兴趣了解更多,请随时阅读脚本,并告诉我们您是否有任何问题或疑虑。
## 脚本参数
@@ -134,11 +128,8 @@ pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
设置`DATASET_NAME`环境变量为Hub中的数据集名称。本指南使用[Naruto BLIP captions](https://huggingface.co/datasets/lambdalabs/naruto-blip-captions)数据集,但您也可以创建和训练自己的数据集(参见[创建用于训练的数据集](create_dataset)指南)。
-
-
-要使用Weights & Biases监控训练进度,请在训练命令中添加`--report_to=wandb`参数。您还需要在训练命令中添加`--validation_prompt`以跟踪结果。这对于调试模型和查看中间结果非常有用。
-
-
+> [!TIP]
+> 要使用Weights & Biases监控训练进度,请在训练命令中添加`--report_to=wandb`参数。您还需要在训练命令中添加`--validation_prompt`以跟踪结果。这对于调试模型和查看中间结果非常有用。
```bash
export DATASET_NAME="lambdalabs/naruto-blip-captions"
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index a46490e8b3..5aa33190d4 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -25,6 +25,10 @@
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
diff --git a/examples/community/README.md b/examples/community/README.md
index e314463077..4a4b0f5fd9 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -1328,7 +1328,7 @@ model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined"
# Load Stable Diffusion Inpainting Pipeline with custom pipeline
pipe = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
+ "stable-diffusion-v1-5/stable-diffusion-inpainting",
custom_pipeline="text_inpainting",
segmentation_model=model,
segmentation_processor=processor
diff --git a/examples/community/adaptive_mask_inpainting.py b/examples/community/adaptive_mask_inpainting.py
index aac460cb46..da67debe72 100644
--- a/examples/community/adaptive_mask_inpainting.py
+++ b/examples/community/adaptive_mask_inpainting.py
@@ -126,7 +126,7 @@ EXAMPLE_DOC_STRING = """
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -347,7 +347,7 @@ class AdaptiveMaskInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -429,8 +429,8 @@ class AdaptiveMaskInpaintPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -970,7 +970,7 @@ class AdaptiveMaskInpaintPipeline(
>>> default_mask_image = download_image(mask_url).resize((512, 512))
>>> pipe = AdaptiveMaskInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-inpainting", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -1095,7 +1095,7 @@ class AdaptiveMaskInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py
index a7c540ceb9..a7c610ad43 100644
--- a/examples/community/composable_stable_diffusion.py
+++ b/examples/community/composable_stable_diffusion.py
@@ -62,7 +62,7 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -145,8 +145,8 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin)
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/fresco_v2v.py b/examples/community/fresco_v2v.py
index 47ba71299d..b79834db5e 100644
--- a/examples/community/fresco_v2v.py
+++ b/examples/community/fresco_v2v.py
@@ -1276,7 +1276,7 @@ class FrescoV2VPipeline(StableDiffusionControlNetImg2ImgPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/community/hd_painter.py b/examples/community/hd_painter.py
index 20bb43a76f..70e5656855 100644
--- a/examples/community/hd_painter.py
+++ b/examples/community/hd_painter.py
@@ -678,7 +678,7 @@ class StableDiffusionHDPainterPipeline(StableDiffusionInpaintPipeline):
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/img2img_inpainting.py b/examples/community/img2img_inpainting.py
index 499230b1e2..595df107ca 100644
--- a/examples/community/img2img_inpainting.py
+++ b/examples/community/img2img_inpainting.py
@@ -78,7 +78,7 @@ class ImageToImageInpaintingPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/instaflow_one_step.py b/examples/community/instaflow_one_step.py
index 06be1d10b6..0f16707ead 100644
--- a/examples/community/instaflow_one_step.py
+++ b/examples/community/instaflow_one_step.py
@@ -86,7 +86,7 @@ class InstaFlowPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -165,8 +165,8 @@ class InstaFlowPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
index 5b420882e9..d16aaf5a54 100644
--- a/examples/community/ip_adapter_face_id.py
+++ b/examples/community/ip_adapter_face_id.py
@@ -166,7 +166,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -247,8 +247,8 @@ class IPAdapterFaceIDStableDiffusionPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/kohya_hires_fix.py b/examples/community/kohya_hires_fix.py
index 63f6b8973c..c968ecf2af 100644
--- a/examples/community/kohya_hires_fix.py
+++ b/examples/community/kohya_hires_fix.py
@@ -414,7 +414,7 @@ class StableDiffusionHighResFixPipeline(StableDiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/community/latent_consistency_interpolate.py b/examples/community/latent_consistency_interpolate.py
index 9fc4233682..e8349ba317 100644
--- a/examples/community/latent_consistency_interpolate.py
+++ b/examples/community/latent_consistency_interpolate.py
@@ -222,7 +222,7 @@ class LatentConsistencyModelWalkPipeline(
supports [`LCMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/community/llm_grounded_diffusion.py b/examples/community/llm_grounded_diffusion.py
index 8f04761502..5bf6674a43 100644
--- a/examples/community/llm_grounded_diffusion.py
+++ b/examples/community/llm_grounded_diffusion.py
@@ -302,7 +302,7 @@ class LLMGroundedDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -392,8 +392,8 @@ class LLMGroundedDiffusionPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/lpw_stable_diffusion.py b/examples/community/lpw_stable_diffusion.py
index cb017c0bbe..58e932bbcf 100644
--- a/examples/community/lpw_stable_diffusion.py
+++ b/examples/community/lpw_stable_diffusion.py
@@ -552,8 +552,8 @@ class StableDiffusionLongPromptWeightingPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/lpw_stable_diffusion_xl.py b/examples/community/lpw_stable_diffusion_xl.py
index 272c5d5652..95d3405df5 100644
--- a/examples/community/lpw_stable_diffusion_xl.py
+++ b/examples/community/lpw_stable_diffusion_xl.py
@@ -1765,7 +1765,7 @@ class SDXLLongPromptWeightingPipeline(
# Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
diff --git a/examples/community/matryoshka.py b/examples/community/matryoshka.py
index 274851e2ac..97ad8b9e86 100644
--- a/examples/community/matryoshka.py
+++ b/examples/community/matryoshka.py
@@ -1475,11 +1475,8 @@ class MatryoshkaFusedAttnProcessor2_0:
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is currently 🧪 experimental in nature and can change in future.
-
-
+ > [!WARNING]
+ > This API is currently 🧪 experimental in nature and can change in future.
"""
def __init__(self):
@@ -2696,11 +2693,8 @@ class MatryoshkaUNet2DConditionModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -2719,11 +2713,8 @@ class MatryoshkaUNet2DConditionModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -3738,8 +3729,8 @@ class MatryoshkaPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/multilingual_stable_diffusion.py b/examples/community/multilingual_stable_diffusion.py
index afef4e9e97..436803f201 100644
--- a/examples/community/multilingual_stable_diffusion.py
+++ b/examples/community/multilingual_stable_diffusion.py
@@ -78,7 +78,7 @@ class MultilingualStableDiffusion(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/pipeline_controlnet_xl_kolors_inpaint.py b/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
index 4b6123cc1f..3abd984829 100644
--- a/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
+++ b/examples/community/pipeline_controlnet_xl_kolors_inpaint.py
@@ -1607,7 +1607,7 @@ class KolorsControlNetInpaintPipeline(
# 9. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/pipeline_fabric.py b/examples/community/pipeline_fabric.py
index dcc7730cbe..d29e98df5e 100644
--- a/examples/community/pipeline_fabric.py
+++ b/examples/community/pipeline_fabric.py
@@ -135,7 +135,7 @@ class FabricPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
"""
@@ -163,8 +163,8 @@ class FabricPipeline(DiffusionPipeline):
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/pipeline_kolors_inpainting.py b/examples/community/pipeline_kolors_inpainting.py
index 3cab8ecac0..26517819eb 100644
--- a/examples/community/pipeline_kolors_inpainting.py
+++ b/examples/community/pipeline_kolors_inpainting.py
@@ -1487,7 +1487,7 @@ class KolorsInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/pipeline_prompt2prompt.py b/examples/community/pipeline_prompt2prompt.py
index 8d94dc9248..eb19667970 100644
--- a/examples/community/pipeline_prompt2prompt.py
+++ b/examples/community/pipeline_prompt2prompt.py
@@ -106,7 +106,7 @@ class Prompt2PromptPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -187,8 +187,8 @@ class Prompt2PromptPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/pipeline_sdxl_style_aligned.py b/examples/community/pipeline_sdxl_style_aligned.py
index 10438af365..51547599f5 100644
--- a/examples/community/pipeline_sdxl_style_aligned.py
+++ b/examples/community/pipeline_sdxl_style_aligned.py
@@ -1730,7 +1730,7 @@ class StyleAlignedSDXLPipeline(
# Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
diff --git a/examples/community/pipeline_stable_diffusion_boxdiff.py b/examples/community/pipeline_stable_diffusion_boxdiff.py
index 1133321fcc..c05a614313 100644
--- a/examples/community/pipeline_stable_diffusion_boxdiff.py
+++ b/examples/community/pipeline_stable_diffusion_boxdiff.py
@@ -59,7 +59,7 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -392,7 +392,7 @@ class StableDiffusionBoxDiffPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -473,8 +473,8 @@ class StableDiffusionBoxDiffPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -948,11 +948,8 @@ class StableDiffusionBoxDiffPipeline(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
@@ -978,11 +975,8 @@ class StableDiffusionBoxDiffPipeline(
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
diff --git a/examples/community/pipeline_stable_diffusion_pag.py b/examples/community/pipeline_stable_diffusion_pag.py
index 6728e2a60b..3f98dca0b9 100644
--- a/examples/community/pipeline_stable_diffusion_pag.py
+++ b/examples/community/pipeline_stable_diffusion_pag.py
@@ -42,7 +42,7 @@ EXAMPLE_DOC_STRING = """
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt).images[0]
@@ -359,7 +359,7 @@ class StableDiffusionPAGPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
@@ -440,8 +440,8 @@ class StableDiffusionPAGPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -940,9 +940,8 @@ class StableDiffusionPAGPipeline(
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
-
- This API is 🧪 experimental.
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
@@ -966,9 +965,8 @@ class StableDiffusionPAGPipeline(
# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.unfuse_qkv_projections
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
-
- This API is 🧪 experimental.
-
+ > [!WARNING]
+ > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
diff --git a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
index 9777633535..e358f66b4a 100644
--- a/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
+++ b/examples/community/pipeline_stable_diffusion_upscale_ldm3d.py
@@ -100,7 +100,7 @@ class StableDiffusionUpscaleLDM3DPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index a881814c2a..65df4c03eb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -2042,7 +2042,7 @@ class StableDiffusionXL_AE_Pipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
index 564a19e923..6dade126f2 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter.py
@@ -188,7 +188,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
index c73433b20f..9ec6a90b07 100644
--- a/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
+++ b/examples/community/pipeline_stable_diffusion_xl_controlnet_adapter_inpaint.py
@@ -330,7 +330,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
@@ -1569,7 +1569,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/examples/community/pipeline_zero1to3.py b/examples/community/pipeline_zero1to3.py
index 9e29566978..1be59fd832 100644
--- a/examples/community/pipeline_zero1to3.py
+++ b/examples/community/pipeline_zero1to3.py
@@ -46,7 +46,7 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+ >>> pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> prompt = "a photo of an astronaut riding a horse on mars"
@@ -86,7 +86,7 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
cc_projection ([`CCProjection`]):
@@ -164,8 +164,8 @@ class Zero1to3StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py
index 78a15a03b0..840f9e206d 100644
--- a/examples/community/rerender_a_video.py
+++ b/examples/community/rerender_a_video.py
@@ -288,7 +288,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/run_onnx_controlnet.py b/examples/community/run_onnx_controlnet.py
index f0ab2a2b96..2b56e8a1e5 100644
--- a/examples/community/run_onnx_controlnet.py
+++ b/examples/community/run_onnx_controlnet.py
@@ -54,7 +54,7 @@ EXAMPLE_DOC_STRING = """
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
diff --git a/examples/community/run_tensorrt_controlnet.py b/examples/community/run_tensorrt_controlnet.py
index e4f1abc83b..b62eb4f58e 100644
--- a/examples/community/run_tensorrt_controlnet.py
+++ b/examples/community/run_tensorrt_controlnet.py
@@ -158,7 +158,7 @@ EXAMPLE_DOC_STRING = """
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
diff --git a/examples/community/sd_text2img_k_diffusion.py b/examples/community/sd_text2img_k_diffusion.py
index 4d5cea497f..e351420f78 100755
--- a/examples/community/sd_text2img_k_diffusion.py
+++ b/examples/community/sd_text2img_k_diffusion.py
@@ -64,7 +64,7 @@ class StableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/sde_drag.py b/examples/community/sde_drag.py
index f408ee64db..63899ce738 100644
--- a/examples/community/sde_drag.py
+++ b/examples/community/sde_drag.py
@@ -114,7 +114,7 @@ class SdeDragPipeline(DiffusionPipeline):
>>> from diffusers import DDIMScheduler, DiffusionPipeline
>>> # Load the pipeline
- >>> model_path = "runwayml/stable-diffusion-v1-5"
+ >>> model_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
>>> scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
>>> pipe = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, custom_pipeline="sde_drag")
>>> pipe.to('cuda')
diff --git a/examples/community/stable_diffusion_comparison.py b/examples/community/stable_diffusion_comparison.py
index 22f3b3e0c3..ce6e77c87f 100644
--- a/examples/community/stable_diffusion_comparison.py
+++ b/examples/community/stable_diffusion_comparison.py
@@ -46,7 +46,7 @@ class StableDiffusionComparisonPipeline(DiffusionPipeline, StableDiffusionMixin)
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/stable_diffusion_controlnet_img2img.py b/examples/community/stable_diffusion_controlnet_img2img.py
index 6d8038cfd4..aa116112be 100644
--- a/examples/community/stable_diffusion_controlnet_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_img2img.py
@@ -36,7 +36,7 @@ EXAMPLE_DOC_STRING = """
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe_controlnet = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
controlnet=controlnet,
safety_checker=None,
torch_dtype=torch.float16
diff --git a/examples/community/stable_diffusion_controlnet_inpaint.py b/examples/community/stable_diffusion_controlnet_inpaint.py
index fe7b808b6b..6d710e0d73 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint.py
@@ -81,7 +81,7 @@ EXAMPLE_DOC_STRING = """
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
diff --git a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
index 2b5dc77fe5..fcb5ed059b 100644
--- a/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
+++ b/examples/community/stable_diffusion_controlnet_inpaint_img2img.py
@@ -80,7 +80,7 @@ EXAMPLE_DOC_STRING = """
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
+ "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, safety_checker=None, torch_dtype=torch.float16
)
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
diff --git a/examples/community/stable_diffusion_controlnet_reference.py b/examples/community/stable_diffusion_controlnet_reference.py
index e5dd249e04..74c81b6362 100644
--- a/examples/community/stable_diffusion_controlnet_reference.py
+++ b/examples/community/stable_diffusion_controlnet_reference.py
@@ -37,7 +37,7 @@ EXAMPLE_DOC_STRING = """
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetReferencePipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
controlnet=controlnet,
safety_checker=None,
torch_dtype=torch.float16
diff --git a/examples/community/stable_diffusion_ipex.py b/examples/community/stable_diffusion_ipex.py
index 7d1cd4f5d0..4f545aa09d 100644
--- a/examples/community/stable_diffusion_ipex.py
+++ b/examples/community/stable_diffusion_ipex.py
@@ -43,7 +43,7 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import StableDiffusionPipeline
- >>> pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
+ >>> pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_ipex")
>>> # For Float32
>>> pipe.prepare_for_ipex(prompt, dtype=torch.float32, height=512, width=512) #value of image height/width should be consistent with the pipeline inference
@@ -85,7 +85,7 @@ class StableDiffusionIPEXPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -161,8 +161,8 @@ class StableDiffusionIPEXPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/stable_diffusion_mega.py b/examples/community/stable_diffusion_mega.py
index 77e5011d2a..c67ebc80b0 100644
--- a/examples/community/stable_diffusion_mega.py
+++ b/examples/community/stable_diffusion_mega.py
@@ -47,7 +47,7 @@ class StableDiffusionMegaPipeline(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionMegaSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py
index 6f7dce9823..d0372bbeba 100644
--- a/examples/community/stable_diffusion_reference.py
+++ b/examples/community/stable_diffusion_reference.py
@@ -46,7 +46,7 @@ EXAMPLE_DOC_STRING = """
>>> input_image = load_image("https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png")
>>> pipe = StableDiffusionReferencePipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5",
+ "stable-diffusion-v1-5/stable-diffusion-v1-5",
safety_checker=None,
torch_dtype=torch.float16
).to('cuda:0')
@@ -112,7 +112,7 @@ class StableDiffusionReferencePipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -194,8 +194,8 @@ class StableDiffusionReferencePipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/stable_diffusion_repaint.py b/examples/community/stable_diffusion_repaint.py
index 94b9f8b01b..b974e3c7ae 100644
--- a/examples/community/stable_diffusion_repaint.py
+++ b/examples/community/stable_diffusion_repaint.py
@@ -167,7 +167,7 @@ class StableDiffusionRepaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -249,8 +249,8 @@ class StableDiffusionRepaintPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/stable_diffusion_tensorrt_img2img.py b/examples/community/stable_diffusion_tensorrt_img2img.py
index dc11703b6a..5b7733fe57 100755
--- a/examples/community/stable_diffusion_tensorrt_img2img.py
+++ b/examples/community/stable_diffusion_tensorrt_img2img.py
@@ -678,7 +678,7 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -766,8 +766,8 @@ class TensorRTStableDiffusionImg2ImgPipeline(DiffusionPipeline):
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py
index fff7309e9c..fc81e4c289 100755
--- a/examples/community/stable_diffusion_tensorrt_inpaint.py
+++ b/examples/community/stable_diffusion_tensorrt_inpaint.py
@@ -682,7 +682,7 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -770,8 +770,8 @@ class TensorRTStableDiffusionInpaintPipeline(DiffusionPipeline):
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py
index 15a6e69c41..e1d09edf93 100755
--- a/examples/community/stable_diffusion_tensorrt_txt2img.py
+++ b/examples/community/stable_diffusion_tensorrt_txt2img.py
@@ -594,7 +594,7 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -682,8 +682,8 @@ class TensorRTStableDiffusionPipeline(DiffusionPipeline):
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/examples/community/text_inpainting.py b/examples/community/text_inpainting.py
index f262cf2cac..bdf9eca498 100644
--- a/examples/community/text_inpainting.py
+++ b/examples/community/text_inpainting.py
@@ -52,7 +52,7 @@ class TextInpainting(DiffusionPipeline, StableDiffusionMixin):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index bd3a974a17..3b6ab814f2 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -25,6 +25,10 @@
# "Jinja2",
# "peft>=0.11.1",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 03c05a05e0..fc6df87768 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -14,6 +14,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
+
import argparse
import copy
import itertools
diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
index feec4da712..56de160d6f 100644
--- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py
+++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py
@@ -13,6 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
+# ]
+# ///
+
import argparse
import copy
import itertools
@@ -1320,7 +1338,7 @@ def main(args):
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=vae.dtype
)
- latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if train_dataset.custom_instance_prompts:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, prompt_embeds_mask = compute_text_embeddings(
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index b188a80916..2b0c1ee669 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -25,6 +25,10 @@
# "Jinja2",
# "peft>=0.14.0",
# "sentencepiece",
+# "torchvision",
+# "datasets",
+# "bitsandbytes",
+# "prodigyopt",
# ]
# ///
diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index fcce297c37..ee5dced817 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -1246,12 +1246,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1355,12 +1352,9 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1504,12 +1498,9 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1614,12 +1605,9 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login`.
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login`.
Examples:
@@ -1763,12 +1751,9 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login
Examples:
@@ -1872,12 +1857,9 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
below for more information.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `hf auth login
-
-
+ > [!TIP]
+ > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
+ > `hf auth login
Examples:
diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py
index 38f0adb891..7ae6ae57c2 100644
--- a/examples/research_projects/anytext/anytext.py
+++ b/examples/research_projects/anytext/anytext.py
@@ -1223,7 +1223,7 @@ class AnyTextPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/research_projects/dreambooth_inpaint/README.md b/examples/research_projects/dreambooth_inpaint/README.md
index 46703fa982..b6ee1d72f6 100644
--- a/examples/research_projects/dreambooth_inpaint/README.md
+++ b/examples/research_projects/dreambooth_inpaint/README.md
@@ -5,7 +5,7 @@ This script was added by @thedarkzeno .
Please note that this script is not actively maintained, you can open an issue and tag @thedarkzeno or @patil-suraj though.
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
@@ -29,7 +29,7 @@ Prior-preservation is used to avoid overfitting and language-drift. Refer to the
According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases.
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -60,7 +60,7 @@ With the help of gradient checkpointing and the 8-bit optimizer from bitsandbyte
To install `bitandbytes` please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -92,7 +92,7 @@ Pass the `--train_text_encoder` argument to the script to enable training `text_
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
diff --git a/examples/research_projects/ip_adapter/README.md b/examples/research_projects/ip_adapter/README.md
index 3df9644ddf..0bead5ae85 100644
--- a/examples/research_projects/ip_adapter/README.md
+++ b/examples/research_projects/ip_adapter/README.md
@@ -55,7 +55,7 @@ The Accelerate launch command is used to train a model using multiple GPUs and m
```
accelerate launch --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
---pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+--pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
@@ -73,7 +73,7 @@ tutorial_train_ip-adapter.py \
```
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
- --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
+ --pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
diff --git a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
index 8ddef1b83c..3412de662f 100644
--- a/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
+++ b/examples/research_projects/multi_subject_dreambooth_inpainting/README.md
@@ -27,7 +27,7 @@ You can build multiple datasets for every subject and upload them to the 🤗 hu
Before launching the training script, make sure to select the inpainting the target model, the output directory and the 🤗 datasets.
```bash
-export MODEL_NAME="runwayml/stable-diffusion-inpainting"
+export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-inpainting"
export OUTPUT_DIR="path-to-save-model"
export DATASET_1="gzguevara/mr_potato_head_masked"
diff --git a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
index 1bd9c0161f..233df12765 100644
--- a/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
+++ b/examples/research_projects/promptdiffusion/pipeline_prompt_diffusion.py
@@ -177,7 +177,7 @@ class PromptDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for more details
about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py
index cdc3a54fdf..922cb42615 100644
--- a/examples/research_projects/vae/vae_roundtrip.py
+++ b/examples/research_projects/vae/vae_roundtrip.py
@@ -238,7 +238,7 @@ def parse_args() -> argparse.Namespace:
# EXAMPLE USAGE:
#
-# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
+# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "stable-diffusion-v1-5/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png"
#
# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png"
#
diff --git a/examples/server-async/Pipelines.py b/examples/server-async/Pipelines.py
new file mode 100644
index 0000000000..f89cac6a7e
--- /dev/null
+++ b/examples/server-async/Pipelines.py
@@ -0,0 +1,91 @@
+import logging
+import os
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+from pydantic import BaseModel
+
+from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
+
+
+logger = logging.getLogger(__name__)
+
+
+class TextToImageInput(BaseModel):
+ model: str
+ prompt: str
+ size: str | None = None
+ n: int | None = None
+
+
+@dataclass
+class PresetModels:
+ SD3: List[str] = field(default_factory=lambda: ["stabilityai/stable-diffusion-3-medium"])
+ SD3_5: List[str] = field(
+ default_factory=lambda: [
+ "stabilityai/stable-diffusion-3.5-large",
+ "stabilityai/stable-diffusion-3.5-large-turbo",
+ "stabilityai/stable-diffusion-3.5-medium",
+ ]
+ )
+
+
+class TextToImagePipelineSD3:
+ def __init__(self, model_path: str | None = None):
+ self.model_path = model_path or os.getenv("MODEL_PATH")
+ self.pipeline: StableDiffusion3Pipeline | None = None
+ self.device: str | None = None
+
+ def start(self):
+ if torch.cuda.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
+ logger.info("Loading CUDA")
+ self.device = "cuda"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.float16,
+ ).to(device=self.device)
+ elif torch.backends.mps.is_available():
+ model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
+ logger.info("Loading MPS for Mac M Series")
+ self.device = "mps"
+ self.pipeline = StableDiffusion3Pipeline.from_pretrained(
+ model_path,
+ torch_dtype=torch.bfloat16,
+ ).to(device=self.device)
+ else:
+ raise Exception("No CUDA or MPS device available")
+
+
+class ModelPipelineInitializer:
+ def __init__(self, model: str = "", type_models: str = "t2im"):
+ self.model = model
+ self.type_models = type_models
+ self.pipeline = None
+ self.device = "cuda" if torch.cuda.is_available() else "mps"
+ self.model_type = None
+
+ def initialize_pipeline(self):
+ if not self.model:
+ raise ValueError("Model name not provided")
+
+ # Check if model exists in PresetModels
+ preset_models = PresetModels()
+
+ # Determine which model type we're dealing with
+ if self.model in preset_models.SD3:
+ self.model_type = "SD3"
+ elif self.model in preset_models.SD3_5:
+ self.model_type = "SD3_5"
+
+ # Create appropriate pipeline based on model type and type_models
+ if self.type_models == "t2im":
+ if self.model_type in ["SD3", "SD3_5"]:
+ self.pipeline = TextToImagePipelineSD3(self.model)
+ else:
+ raise ValueError(f"Model type {self.model_type} not supported for text-to-image")
+ elif self.type_models == "t2v":
+ raise ValueError(f"Unsupported type_models: {self.type_models}")
+
+ return self.pipeline
diff --git a/examples/server-async/README.md b/examples/server-async/README.md
new file mode 100644
index 0000000000..a47ab7c7f2
--- /dev/null
+++ b/examples/server-async/README.md
@@ -0,0 +1,171 @@
+# Asynchronous server and parallel execution of models
+
+> Example/demo server that keeps a single model in memory while safely running parallel inference requests by creating per-request lightweight views and cloning only small, stateful components (schedulers, RNG state, small mutable attrs). Works with StableDiffusion3 pipelines.
+> We recommend running 10 to 50 inferences in parallel for optimal performance, averaging between 25 and 30 seconds to 1 minute and 1 minute and 30 seconds. (This is only recommended if you have a GPU with 35GB of VRAM or more; otherwise, keep it to one or two inferences in parallel to avoid decoding or saving errors due to memory shortages.)
+
+## ⚠️ IMPORTANT
+
+* The example demonstrates how to run pipelines like `StableDiffusion3-3.5` concurrently while keeping a single copy of the heavy model parameters on GPU.
+
+## Necessary components
+
+All the components needed to create the inference server are in the current directory:
+
+```
+server-async/
+├── utils/
+├─────── __init__.py
+├─────── scheduler.py # BaseAsyncScheduler wrapper and async_retrieve_timesteps for secure inferences
+├─────── requestscopedpipeline.py # RequestScoped Pipeline for inference with a single in-memory model
+├─────── utils.py # Image/video saving utilities and service configuration
+├── Pipelines.py # pipeline loader classes (SD3)
+├── serverasync.py # FastAPI app with lifespan management and async inference endpoints
+├── test.py # Client test script for inference requests
+├── requirements.txt # Dependencies
+└── README.md # This documentation
+```
+
+## What `diffusers-async` adds / Why we needed it
+
+Core problem: a naive server that calls `pipe.__call__` concurrently can hit **race conditions** (e.g., `scheduler.set_timesteps` mutates shared state) or explode memory by deep-copying the whole pipeline per-request.
+
+`diffusers-async` / this example addresses that by:
+
+* **Request-scoped views**: `RequestScopedPipeline` creates a shallow copy of the pipeline per request so heavy weights (UNet, VAE, text encoder) remain shared and *are not duplicated*.
+* **Per-request mutable state**: stateful small objects (scheduler, RNG state, small lists/dicts, callbacks) are cloned per request. The system uses `BaseAsyncScheduler.clone_for_request(...)` for scheduler cloning, with fallback to safe `deepcopy` or other heuristics.
+* **Tokenizer concurrency safety**: `RequestScopedPipeline` now manages an internal tokenizer lock with automatic tokenizer detection and wrapping. This ensures that Rust tokenizers are safe to use under concurrency — race condition errors like `Already borrowed` no longer occur.
+* **`async_retrieve_timesteps(..., return_scheduler=True)`**: fully retro-compatible helper that returns `(timesteps, num_inference_steps, scheduler)` without mutating the shared scheduler. For users not using `return_scheduler=True`, the behavior is identical to the original API.
+* **Robust attribute handling**: wrapper avoids writing to read-only properties (e.g., `components`) and auto-detects small mutable attributes to clone while avoiding duplication of large tensors. Configurable tensor size threshold prevents cloning of large tensors.
+* **Enhanced scheduler wrapping**: `BaseAsyncScheduler` automatically wraps schedulers with improved `__getattr__`, `__setattr__`, and debugging methods (`__repr__`, `__str__`).
+
+## How the server works (high-level flow)
+
+1. **Single model instance** is loaded into memory (GPU/MPS) when the server starts.
+2. On each HTTP inference request:
+
+ * The server uses `RequestScopedPipeline.generate(...)` which:
+
+ * automatically wraps the base scheduler in `BaseAsyncScheduler` (if not already wrapped),
+ * obtains a *local scheduler* (via `clone_for_request()` or `deepcopy`),
+ * does `local_pipe = copy.copy(base_pipe)` (shallow copy),
+ * sets `local_pipe.scheduler = local_scheduler` (if possible),
+ * clones only small mutable attributes (callbacks, rng, small latents) with auto-detection,
+ * wraps tokenizers with thread-safe locks to prevent race conditions,
+ * optionally enters a `model_cpu_offload_context()` for memory offload hooks,
+ * calls the pipeline on the local view (`local_pipe(...)`).
+3. **Result**: inference completes, images are moved to CPU & saved (if requested), internal buffers freed (GC + `torch.cuda.empty_cache()`).
+4. Multiple requests can run in parallel while sharing heavy weights and isolating mutable state.
+
+## How to set up and run the server
+
+### 1) Install dependencies
+
+Recommended: create a virtualenv / conda environment.
+
+```bash
+pip install diffusers
+pip install -r requirements.txt
+```
+
+### 2) Start the server
+
+Using the `serverasync.py` file that already has everything you need:
+
+```bash
+python serverasync.py
+```
+
+The server will start on `http://localhost:8500` by default with the following features:
+- FastAPI application with async lifespan management
+- Automatic model loading and pipeline initialization
+- Request counting and active inference tracking
+- Memory cleanup after each inference
+- CORS middleware for cross-origin requests
+
+### 3) Test the server
+
+Use the included test script:
+
+```bash
+python test.py
+```
+
+Or send a manual request:
+
+`POST /api/diffusers/inference` with JSON body:
+
+```json
+{
+ "prompt": "A futuristic cityscape, vibrant colors",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1
+}
+```
+
+Response example:
+
+```json
+{
+ "response": ["http://localhost:8500/images/img123.png"]
+}
+```
+
+### 4) Server endpoints
+
+- `GET /` - Welcome message
+- `POST /api/diffusers/inference` - Main inference endpoint
+- `GET /images/{filename}` - Serve generated images
+- `GET /api/status` - Server status and memory info
+
+## Advanced Configuration
+
+### RequestScopedPipeline Parameters
+
+```python
+RequestScopedPipeline(
+ pipeline, # Base pipeline to wrap
+ mutable_attrs=None, # Custom list of attributes to clone
+ auto_detect_mutables=True, # Enable automatic detection of mutable attributes
+ tensor_numel_threshold=1_000_000, # Tensor size threshold for cloning
+ tokenizer_lock=None, # Custom threading lock for tokenizers
+ wrap_scheduler=True # Auto-wrap scheduler in BaseAsyncScheduler
+)
+```
+
+### BaseAsyncScheduler Features
+
+* Transparent proxy to the original scheduler with `__getattr__` and `__setattr__`
+* `clone_for_request()` method for safe per-request scheduler cloning
+* Enhanced debugging with `__repr__` and `__str__` methods
+* Full compatibility with existing scheduler APIs
+
+### Server Configuration
+
+The server configuration can be modified in `serverasync.py` through the `ServerConfigModels` dataclass:
+
+```python
+@dataclass
+class ServerConfigModels:
+ model: str = 'stabilityai/stable-diffusion-3.5-medium'
+ type_models: str = 't2im'
+ host: str = '0.0.0.0'
+ port: int = 8500
+```
+
+## Troubleshooting (quick)
+
+* `Already borrowed` — previously a Rust tokenizer concurrency error.
+ ✅ This is now fixed: `RequestScopedPipeline` automatically detects and wraps tokenizers with thread locks, so race conditions no longer happen.
+
+* `can't set attribute 'components'` — pipeline exposes read-only `components`.
+ ✅ The RequestScopedPipeline now detects read-only properties and skips setting them automatically.
+
+* Scheduler issues:
+ * If the scheduler doesn't implement `clone_for_request` and `deepcopy` fails, we log and fallback — but prefer `async_retrieve_timesteps(..., return_scheduler=True)` to avoid mutating the shared scheduler.
+ ✅ Note: `async_retrieve_timesteps` is fully retro-compatible — if you don't pass `return_scheduler=True`, the behavior is unchanged.
+
+* Memory issues with large tensors:
+ ✅ The system now has configurable `tensor_numel_threshold` to prevent cloning of large tensors while still cloning small mutable ones.
+
+* Automatic tokenizer detection:
+ ✅ The system automatically identifies tokenizer components by checking for tokenizer methods, class names, and attributes, then applies thread-safe wrappers.
\ No newline at end of file
diff --git a/examples/server-async/requirements.txt b/examples/server-async/requirements.txt
new file mode 100644
index 0000000000..aafa93b702
--- /dev/null
+++ b/examples/server-async/requirements.txt
@@ -0,0 +1,10 @@
+torch
+torchvision
+transformers
+sentencepiece
+fastapi
+uvicorn
+ftfy
+accelerate
+xformers
+protobuf
\ No newline at end of file
diff --git a/examples/server-async/serverasync.py b/examples/server-async/serverasync.py
new file mode 100644
index 0000000000..b279b36f9a
--- /dev/null
+++ b/examples/server-async/serverasync.py
@@ -0,0 +1,230 @@
+import asyncio
+import gc
+import logging
+import os
+import random
+import threading
+from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, Dict, Optional, Type
+
+import torch
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.concurrency import run_in_threadpool
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import FileResponse
+from Pipelines import ModelPipelineInitializer
+from pydantic import BaseModel
+
+from utils import RequestScopedPipeline, Utils
+
+
+@dataclass
+class ServerConfigModels:
+ model: str = "stabilityai/stable-diffusion-3.5-medium"
+ type_models: str = "t2im"
+ constructor_pipeline: Optional[Type] = None
+ custom_pipeline: Optional[Type] = None
+ components: Optional[Dict[str, Any]] = None
+ torch_dtype: Optional[torch.dtype] = None
+ host: str = "0.0.0.0"
+ port: int = 8500
+
+
+server_config = ServerConfigModels()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ logging.basicConfig(level=logging.INFO)
+ app.state.logger = logging.getLogger("diffusers-server")
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
+ os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
+
+ app.state.total_requests = 0
+ app.state.active_inferences = 0
+ app.state.metrics_lock = asyncio.Lock()
+ app.state.metrics_task = None
+
+ app.state.utils_app = Utils(
+ host=server_config.host,
+ port=server_config.port,
+ )
+
+ async def metrics_loop():
+ try:
+ while True:
+ async with app.state.metrics_lock:
+ total = app.state.total_requests
+ active = app.state.active_inferences
+ app.state.logger.info(f"[METRICS] total_requests={total} active_inferences={active}")
+ await asyncio.sleep(5)
+ except asyncio.CancelledError:
+ app.state.logger.info("Metrics loop cancelled")
+ raise
+
+ app.state.metrics_task = asyncio.create_task(metrics_loop())
+
+ try:
+ yield
+ finally:
+ task = app.state.metrics_task
+ if task:
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+
+ try:
+ stop_fn = getattr(model_pipeline, "stop", None) or getattr(model_pipeline, "close", None)
+ if callable(stop_fn):
+ await run_in_threadpool(stop_fn)
+ except Exception as e:
+ app.state.logger.warning(f"Error during pipeline shutdown: {e}")
+
+ app.state.logger.info("Lifespan shutdown complete")
+
+
+app = FastAPI(lifespan=lifespan)
+
+logger = logging.getLogger("DiffusersServer.Pipelines")
+
+
+initializer = ModelPipelineInitializer(
+ model=server_config.model,
+ type_models=server_config.type_models,
+)
+model_pipeline = initializer.initialize_pipeline()
+model_pipeline.start()
+
+request_pipe = RequestScopedPipeline(model_pipeline.pipeline)
+pipeline_lock = threading.Lock()
+
+logger.info(f"Pipeline initialized and ready to receive requests (model ={server_config.model})")
+
+app.state.MODEL_INITIALIZER = initializer
+app.state.MODEL_PIPELINE = model_pipeline
+app.state.REQUEST_PIPE = request_pipe
+app.state.PIPELINE_LOCK = pipeline_lock
+
+
+class JSONBodyQueryAPI(BaseModel):
+ model: str | None = None
+ prompt: str
+ negative_prompt: str | None = None
+ num_inference_steps: int = 28
+ num_images_per_prompt: int = 1
+
+
+@app.middleware("http")
+async def count_requests_middleware(request: Request, call_next):
+ async with app.state.metrics_lock:
+ app.state.total_requests += 1
+ response = await call_next(request)
+ return response
+
+
+@app.get("/")
+async def root():
+ return {"message": "Welcome to the Diffusers Server"}
+
+
+@app.post("/api/diffusers/inference")
+async def api(json: JSONBodyQueryAPI):
+ prompt = json.prompt
+ negative_prompt = json.negative_prompt or ""
+ num_steps = json.num_inference_steps
+ num_images_per_prompt = json.num_images_per_prompt
+
+ wrapper = app.state.MODEL_PIPELINE
+ initializer = app.state.MODEL_INITIALIZER
+
+ utils_app = app.state.utils_app
+
+ if not wrapper or not wrapper.pipeline:
+ raise HTTPException(500, "Model not initialized correctly")
+ if not prompt.strip():
+ raise HTTPException(400, "No prompt provided")
+
+ def make_generator():
+ g = torch.Generator(device=initializer.device)
+ return g.manual_seed(random.randint(0, 10_000_000))
+
+ req_pipe = app.state.REQUEST_PIPE
+
+ def infer():
+ gen = make_generator()
+ return req_pipe.generate(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ generator=gen,
+ num_inference_steps=num_steps,
+ num_images_per_prompt=num_images_per_prompt,
+ device=initializer.device,
+ output_type="pil",
+ )
+
+ try:
+ async with app.state.metrics_lock:
+ app.state.active_inferences += 1
+
+ output = await run_in_threadpool(infer)
+
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+
+ urls = [utils_app.save_image(img) for img in output.images]
+ return {"response": urls}
+
+ except Exception as e:
+ async with app.state.metrics_lock:
+ app.state.active_inferences = max(0, app.state.active_inferences - 1)
+ logger.error(f"Error during inference: {e}")
+ raise HTTPException(500, f"Error in processing: {e}")
+
+ finally:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+ torch.cuda.ipc_collect()
+ gc.collect()
+
+
+@app.get("/images/{filename}")
+async def serve_image(filename: str):
+ utils_app = app.state.utils_app
+ file_path = os.path.join(utils_app.image_dir, filename)
+ if not os.path.isfile(file_path):
+ raise HTTPException(status_code=404, detail="Image not found")
+ return FileResponse(file_path, media_type="image/png")
+
+
+@app.get("/api/status")
+async def get_status():
+ memory_info = {}
+ if torch.cuda.is_available():
+ memory_allocated = torch.cuda.memory_allocated() / 1024**3 # GB
+ memory_reserved = torch.cuda.memory_reserved() / 1024**3 # GB
+ memory_info = {
+ "memory_allocated_gb": round(memory_allocated, 2),
+ "memory_reserved_gb": round(memory_reserved, 2),
+ "device": torch.cuda.get_device_name(0),
+ }
+
+ return {"current_model": server_config.model, "type_models": server_config.type_models, "memory": memory_info}
+
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+if __name__ == "__main__":
+ import uvicorn
+
+ uvicorn.run(app, host=server_config.host, port=server_config.port)
diff --git a/examples/server-async/test.py b/examples/server-async/test.py
new file mode 100644
index 0000000000..e67317ea8f
--- /dev/null
+++ b/examples/server-async/test.py
@@ -0,0 +1,65 @@
+import os
+import time
+import urllib.parse
+
+import requests
+
+
+SERVER_URL = "http://localhost:8500/api/diffusers/inference"
+BASE_URL = "http://localhost:8500"
+DOWNLOAD_FOLDER = "generated_images"
+WAIT_BEFORE_DOWNLOAD = 2 # seconds
+
+os.makedirs(DOWNLOAD_FOLDER, exist_ok=True)
+
+
+def save_from_url(url: str) -> str:
+ """Download the given URL (relative or absolute) and save it locally."""
+ if url.startswith("/"):
+ direct = BASE_URL.rstrip("/") + url
+ else:
+ direct = url
+ resp = requests.get(direct, timeout=60)
+ resp.raise_for_status()
+ filename = os.path.basename(urllib.parse.urlparse(direct).path) or f"img_{int(time.time())}.png"
+ path = os.path.join(DOWNLOAD_FOLDER, filename)
+ with open(path, "wb") as f:
+ f.write(resp.content)
+ return path
+
+
+def main():
+ payload = {
+ "prompt": "The T-800 Terminator Robot Returning From The Future, Anime Style",
+ "num_inference_steps": 30,
+ "num_images_per_prompt": 1,
+ }
+
+ print("Sending request...")
+ try:
+ r = requests.post(SERVER_URL, json=payload, timeout=480)
+ r.raise_for_status()
+ except Exception as e:
+ print(f"Request failed: {e}")
+ return
+
+ body = r.json().get("response", [])
+ # Normalize to a list
+ urls = body if isinstance(body, list) else [body] if body else []
+ if not urls:
+ print("No URLs found in the response. Check the server output.")
+ return
+
+ print(f"Received {len(urls)} URL(s). Waiting {WAIT_BEFORE_DOWNLOAD}s before downloading...")
+ time.sleep(WAIT_BEFORE_DOWNLOAD)
+
+ for u in urls:
+ try:
+ path = save_from_url(u)
+ print(f"Image saved to: {path}")
+ except Exception as e:
+ print(f"Error downloading {u}: {e}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/server-async/utils/__init__.py b/examples/server-async/utils/__init__.py
new file mode 100644
index 0000000000..731cfe491a
--- /dev/null
+++ b/examples/server-async/utils/__init__.py
@@ -0,0 +1,2 @@
+from .requestscopedpipeline import RequestScopedPipeline
+from .utils import Utils
diff --git a/examples/server-async/utils/requestscopedpipeline.py b/examples/server-async/utils/requestscopedpipeline.py
new file mode 100644
index 0000000000..57d1e25671
--- /dev/null
+++ b/examples/server-async/utils/requestscopedpipeline.py
@@ -0,0 +1,296 @@
+import copy
+import threading
+from typing import Any, Iterable, List, Optional
+
+import torch
+
+from diffusers.utils import logging
+
+from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
+
+
+logger = logging.get_logger(__name__)
+
+
+def safe_tokenize(tokenizer, *args, lock, **kwargs):
+ with lock:
+ return tokenizer(*args, **kwargs)
+
+
+class RequestScopedPipeline:
+ DEFAULT_MUTABLE_ATTRS = [
+ "_all_hooks",
+ "_offload_device",
+ "_progress_bar_config",
+ "_progress_bar",
+ "_rng_state",
+ "_last_seed",
+ "latents",
+ ]
+
+ def __init__(
+ self,
+ pipeline: Any,
+ mutable_attrs: Optional[Iterable[str]] = None,
+ auto_detect_mutables: bool = True,
+ tensor_numel_threshold: int = 1_000_000,
+ tokenizer_lock: Optional[threading.Lock] = None,
+ wrap_scheduler: bool = True,
+ ):
+ self._base = pipeline
+ self.unet = getattr(pipeline, "unet", None)
+ self.vae = getattr(pipeline, "vae", None)
+ self.text_encoder = getattr(pipeline, "text_encoder", None)
+ self.components = getattr(pipeline, "components", None)
+
+ if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
+ if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
+ pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
+
+ self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
+ self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
+
+ self._auto_detect_mutables = bool(auto_detect_mutables)
+ self._tensor_numel_threshold = int(tensor_numel_threshold)
+
+ self._auto_detected_attrs: List[str] = []
+
+ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
+ base_sched = getattr(self._base, "scheduler", None)
+ if base_sched is None:
+ return None
+
+ if not isinstance(base_sched, BaseAsyncScheduler):
+ wrapped_scheduler = BaseAsyncScheduler(base_sched)
+ else:
+ wrapped_scheduler = base_sched
+
+ try:
+ return wrapped_scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps, device=device, **clone_kwargs
+ )
+ except Exception as e:
+ logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
+ try:
+ return copy.deepcopy(wrapped_scheduler)
+ except Exception as e:
+ logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
+ return wrapped_scheduler
+
+ def _autodetect_mutables(self, max_attrs: int = 40):
+ if not self._auto_detect_mutables:
+ return []
+
+ if self._auto_detected_attrs:
+ return self._auto_detected_attrs
+
+ candidates: List[str] = []
+ seen = set()
+ for name in dir(self._base):
+ if name.startswith("__"):
+ continue
+ if name in self._mutable_attrs:
+ continue
+ if name in ("to", "save_pretrained", "from_pretrained"):
+ continue
+ try:
+ val = getattr(self._base, name)
+ except Exception:
+ continue
+
+ import types
+
+ # skip callables and modules
+ if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
+ continue
+
+ # containers -> candidate
+ if isinstance(val, (dict, list, set, tuple, bytearray)):
+ candidates.append(name)
+ seen.add(name)
+ else:
+ # try Tensor detection
+ try:
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ candidates.append(name)
+ seen.add(name)
+ else:
+ logger.debug(f"Ignoring large tensor attr '{name}', numel={val.numel()}")
+ except Exception:
+ continue
+
+ if len(candidates) >= max_attrs:
+ break
+
+ self._auto_detected_attrs = candidates
+ logger.debug(f"Autodetected mutable attrs to clone: {self._auto_detected_attrs}")
+ return self._auto_detected_attrs
+
+ def _is_readonly_property(self, base_obj, attr_name: str) -> bool:
+ try:
+ cls = type(base_obj)
+ descriptor = getattr(cls, attr_name, None)
+ if isinstance(descriptor, property):
+ return descriptor.fset is None
+ if hasattr(descriptor, "__set__") is False and descriptor is not None:
+ return False
+ except Exception:
+ pass
+ return False
+
+ def _clone_mutable_attrs(self, base, local):
+ attrs_to_clone = list(self._mutable_attrs)
+ attrs_to_clone.extend(self._autodetect_mutables())
+
+ EXCLUDE_ATTRS = {
+ "components",
+ }
+
+ for attr in attrs_to_clone:
+ if attr in EXCLUDE_ATTRS:
+ logger.debug(f"Skipping excluded attr '{attr}'")
+ continue
+ if not hasattr(base, attr):
+ continue
+ if self._is_readonly_property(base, attr):
+ logger.debug(f"Skipping read-only property '{attr}'")
+ continue
+
+ try:
+ val = getattr(base, attr)
+ except Exception as e:
+ logger.debug(f"Could not getattr('{attr}') on base pipeline: {e}")
+ continue
+
+ try:
+ if isinstance(val, dict):
+ setattr(local, attr, dict(val))
+ elif isinstance(val, (list, tuple, set)):
+ setattr(local, attr, list(val))
+ elif isinstance(val, bytearray):
+ setattr(local, attr, bytearray(val))
+ else:
+ # small tensors or atomic values
+ if isinstance(val, torch.Tensor):
+ if val.numel() <= self._tensor_numel_threshold:
+ setattr(local, attr, val.clone())
+ else:
+ # don't clone big tensors, keep reference
+ setattr(local, attr, val)
+ else:
+ try:
+ setattr(local, attr, copy.copy(val))
+ except Exception:
+ setattr(local, attr, val)
+ except (AttributeError, TypeError) as e:
+ logger.debug(f"Skipping cloning attribute '{attr}' because it is not settable: {e}")
+ continue
+ except Exception as e:
+ logger.debug(f"Unexpected error cloning attribute '{attr}': {e}")
+ continue
+
+ def _is_tokenizer_component(self, component) -> bool:
+ if component is None:
+ return False
+
+ tokenizer_methods = ["encode", "decode", "tokenize", "__call__"]
+ has_tokenizer_methods = any(hasattr(component, method) for method in tokenizer_methods)
+
+ class_name = component.__class__.__name__.lower()
+ has_tokenizer_in_name = "tokenizer" in class_name
+
+ tokenizer_attrs = ["vocab_size", "pad_token", "eos_token", "bos_token"]
+ has_tokenizer_attrs = any(hasattr(component, attr) for attr in tokenizer_attrs)
+
+ return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
+
+ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
+ local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
+
+ try:
+ local_pipe = copy.copy(self._base)
+ except Exception as e:
+ logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
+ local_pipe = copy.deepcopy(self._base)
+
+ if local_scheduler is not None:
+ try:
+ timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
+ local_scheduler.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ return_scheduler=True,
+ **{k: v for k, v in kwargs.items() if k in ["timesteps", "sigmas"]},
+ )
+
+ final_scheduler = BaseAsyncScheduler(configured_scheduler)
+ setattr(local_pipe, "scheduler", final_scheduler)
+ except Exception:
+ logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
+
+ self._clone_mutable_attrs(self._base, local_pipe)
+
+ # 4) wrap tokenizers on the local pipe with the lock wrapper
+ tokenizer_wrappers = {} # name -> original_tokenizer
+ try:
+ # a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
+ for name in dir(local_pipe):
+ if "tokenizer" in name and not name.startswith("_"):
+ tok = getattr(local_pipe, name, None)
+ if tok is not None and self._is_tokenizer_component(tok):
+ tokenizer_wrappers[name] = tok
+ setattr(
+ local_pipe,
+ name,
+ lambda *args, tok=tok, **kwargs: safe_tokenize(
+ tok, *args, lock=self._tokenizer_lock, **kwargs
+ ),
+ )
+
+ # b) wrap tokenizers in components dict
+ if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
+ for key, val in local_pipe.components.items():
+ if val is None:
+ continue
+
+ if self._is_tokenizer_component(val):
+ tokenizer_wrappers[f"components[{key}]"] = val
+ local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
+ tokenizer, *args, lock=self._tokenizer_lock, **kwargs
+ )
+
+ except Exception as e:
+ logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
+
+ result = None
+ cm = getattr(local_pipe, "model_cpu_offload_context", None)
+ try:
+ if callable(cm):
+ try:
+ with cm():
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except TypeError:
+ # cm might be a context manager instance rather than callable
+ try:
+ with cm:
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ except Exception as e:
+ logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+ else:
+ # no offload context available — call directly
+ result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
+
+ return result
+
+ finally:
+ try:
+ for name, tok in tokenizer_wrappers.items():
+ if name.startswith("components["):
+ key = name[len("components[") : -1]
+ local_pipe.components[key] = tok
+ else:
+ setattr(local_pipe, name, tok)
+ except Exception as e:
+ logger.debug(f"Error restoring wrapped tokenizers: {e}")
diff --git a/examples/server-async/utils/scheduler.py b/examples/server-async/utils/scheduler.py
new file mode 100644
index 0000000000..86d47cac61
--- /dev/null
+++ b/examples/server-async/utils/scheduler.py
@@ -0,0 +1,141 @@
+import copy
+import inspect
+from typing import Any, List, Optional, Union
+
+import torch
+
+
+class BaseAsyncScheduler:
+ def __init__(self, scheduler: Any):
+ self.scheduler = scheduler
+
+ def __getattr__(self, name: str):
+ if hasattr(self.scheduler, name):
+ return getattr(self.scheduler, name)
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ def __setattr__(self, name: str, value):
+ if name == "scheduler":
+ super().__setattr__(name, value)
+ else:
+ if hasattr(self, "scheduler") and hasattr(self.scheduler, name):
+ setattr(self.scheduler, name, value)
+ else:
+ super().__setattr__(name, value)
+
+ def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
+ local = copy.deepcopy(self.scheduler)
+ local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
+ cloned = self.__class__(local)
+ return cloned
+
+ def __repr__(self):
+ return f"BaseAsyncScheduler({repr(self.scheduler)})"
+
+ def __str__(self):
+ return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
+
+
+def async_retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
+ Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Backwards compatible: by default the function behaves exactly as before and returns
+ (timesteps_tensor, num_inference_steps)
+
+ If the caller passes `return_scheduler=True` in kwargs, the function will **not** mutate the passed
+ scheduler. Instead it will use a cloned scheduler if available (via `scheduler.clone_for_request`)
+ or a deepcopy fallback, call `set_timesteps` on that cloned scheduler, and return:
+ (timesteps_tensor, num_inference_steps, scheduler_in_use)
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Optional kwargs:
+ return_scheduler (bool, default False): if True, return (timesteps, num_inference_steps, scheduler_in_use)
+ where `scheduler_in_use` is a scheduler instance that already has timesteps set.
+ This mode will prefer `scheduler.clone_for_request(...)` if available, to avoid mutating the original scheduler.
+
+ Returns:
+ `(timesteps_tensor, num_inference_steps)` by default (backwards compatible), or
+ `(timesteps_tensor, num_inference_steps, scheduler_in_use)` if `return_scheduler=True`.
+ """
+ # pop our optional control kwarg (keeps compatibility)
+ return_scheduler = bool(kwargs.pop("return_scheduler", False))
+
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+
+ # choose scheduler to call set_timesteps on
+ scheduler_in_use = scheduler
+ if return_scheduler:
+ # Do not mutate the provided scheduler: prefer to clone if possible
+ if hasattr(scheduler, "clone_for_request"):
+ try:
+ # clone_for_request may accept num_inference_steps or other kwargs; be permissive
+ scheduler_in_use = scheduler.clone_for_request(
+ num_inference_steps=num_inference_steps or 0, device=device
+ )
+ except Exception:
+ scheduler_in_use = copy.deepcopy(scheduler)
+ else:
+ # fallback deepcopy (scheduler tends to be smallish - acceptable)
+ scheduler_in_use = copy.deepcopy(scheduler)
+
+ # helper to test if set_timesteps supports a particular kwarg
+ def _accepts(param_name: str) -> bool:
+ try:
+ return param_name in set(inspect.signature(scheduler_in_use.set_timesteps).parameters.keys())
+ except (ValueError, TypeError):
+ # if signature introspection fails, be permissive and attempt the call later
+ return False
+
+ # now call set_timesteps on the chosen scheduler_in_use (may be original or clone)
+ if timesteps is not None:
+ accepts_timesteps = _accepts("timesteps")
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ elif sigmas is not None:
+ accept_sigmas = _accepts("sigmas")
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler_in_use.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler_in_use.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+ num_inference_steps = len(timesteps_out)
+ else:
+ # default path
+ scheduler_in_use.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps_out = scheduler_in_use.timesteps
+
+ if return_scheduler:
+ return timesteps_out, num_inference_steps, scheduler_in_use
+ return timesteps_out, num_inference_steps
diff --git a/examples/server-async/utils/utils.py b/examples/server-async/utils/utils.py
new file mode 100644
index 0000000000..9f94330512
--- /dev/null
+++ b/examples/server-async/utils/utils.py
@@ -0,0 +1,48 @@
+import gc
+import logging
+import os
+import tempfile
+import uuid
+
+import torch
+
+
+logger = logging.getLogger(__name__)
+
+
+class Utils:
+ def __init__(self, host: str = "0.0.0.0", port: int = 8500):
+ self.service_url = f"http://{host}:{port}"
+ self.image_dir = os.path.join(tempfile.gettempdir(), "images")
+ if not os.path.exists(self.image_dir):
+ os.makedirs(self.image_dir)
+
+ self.video_dir = os.path.join(tempfile.gettempdir(), "videos")
+ if not os.path.exists(self.video_dir):
+ os.makedirs(self.video_dir)
+
+ def save_image(self, image):
+ if hasattr(image, "to"):
+ try:
+ image = image.to("cpu")
+ except Exception:
+ pass
+
+ if isinstance(image, torch.Tensor):
+ from torchvision import transforms
+
+ to_pil = transforms.ToPILImage()
+ image = to_pil(image.squeeze(0).clamp(0, 1))
+
+ filename = "img" + str(uuid.uuid4()).split("-")[0] + ".png"
+ image_path = os.path.join(self.image_dir, filename)
+ logger.info(f"Saving image to {image_path}")
+
+ image.save(image_path, format="PNG", optimize=True)
+
+ del image
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return os.path.join(self.service_url, "images", filename)
diff --git a/examples/server/README.md b/examples/server/README.md
index 8ad0ed3cbe..f8cd58fc1c 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -9,8 +9,8 @@ This guide will show you how to use the [`StableDiffusion3Pipeline`] in a server
Start by navigating to the `examples/server` folder and installing all of the dependencies.
```py
-pip install .
-pip install -f requirements.txt
+pip install diffusers
+pip install -r requirements.txt
```
Launch the server with the following command.
diff --git a/examples/server/requirements.in b/examples/server/requirements.in
index a469569a10..f8c35d48cd 100644
--- a/examples/server/requirements.in
+++ b/examples/server/requirements.in
@@ -6,4 +6,5 @@ py-consul
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
fastapi
-uvicorn
\ No newline at end of file
+uvicorn
+accelerate
diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt
index b91a8861a0..688a4ee94f 100644
--- a/examples/server/requirements.txt
+++ b/examples/server/requirements.txt
@@ -39,7 +39,7 @@ fsspec==2024.10.0
# torch
h11==0.14.0
# via uvicorn
-huggingface-hub==0.26.1
+huggingface-hub==0.35.0
# via
# tokenizers
# transformers
diff --git a/examples/text_to_image/requirements.txt b/examples/text_to_image/requirements.txt
index c3ffa42f0e..be05fe3fcd 100644
--- a/examples/text_to_image/requirements.txt
+++ b/examples/text_to_image/requirements.txt
@@ -5,4 +5,4 @@ datasets>=2.19.1
ftfy
tensorboard
Jinja2
-peft==0.7.0
+peft>=0.17.0
diff --git a/examples/text_to_image/requirements_sdxl.txt b/examples/text_to_image/requirements_sdxl.txt
index 64cbc9205f..4dacc26ce4 100644
--- a/examples/text_to_image/requirements_sdxl.txt
+++ b/examples/text_to_image/requirements_sdxl.txt
@@ -5,4 +5,4 @@ ftfy
tensorboard
Jinja2
datasets
-peft==0.7.0
\ No newline at end of file
+peft>=0.17.0
\ No newline at end of file
diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index 22f982509b..6f8276a632 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.
+If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
+
Below, we explain both in more detail.
#### Provide the dataset as a folder
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 3ffeef1364..0cc96220b9 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape)
+def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
+ """
+ if tensor.ndim == 2:
+ tensor = tensor.unsqueeze(0)
+ channels = tensor.shape[0]
+ if channels == 3:
+ return tensor
+ if channels == 1:
+ return tensor.repeat(3, 1, 1)
+ if channels == 2:
+ return torch.cat([tensor, tensor[:1]], dim=0)
+ if channels > 3:
+ return tensor[:3]
+ raise ValueError(f"Unsupported number of channels: {channels}")
+
+
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -260,6 +278,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
+ parser.add_argument(
+ "--preserve_input_precision",
+ action="store_true",
+ help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
+ )
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -453,19 +476,41 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
+ spatial_augmentations = [
+ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+ transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ ]
+
augmentations = transforms.Compose(
- [
- transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
- transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
- transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+ spatial_augmentations
+ + [
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
+ precision_augmentations = transforms.Compose(
+ [
+ transforms.PILToTensor(),
+ transforms.Lambda(_ensure_three_channels),
+ transforms.ConvertImageDtype(torch.float32),
+ ]
+ + spatial_augmentations
+ + [transforms.Normalize([0.5], [0.5])]
+ )
+
def transform_images(examples):
- images = [augmentations(image.convert("RGB")) for image in examples["image"]]
- return {"input": images}
+ processed = []
+ for image in examples["image"]:
+ if not args.preserve_input_precision:
+ processed.append(augmentations(image.convert("RGB")))
+ else:
+ precise_image = image
+ if precise_image.mode == "P":
+ precise_image = precise_image.convert("RGB")
+ processed.append(precision_augmentations(precise_image))
+ return {"input": processed}
logger.info(f"Dataset size: {len(dataset)}")
diff --git a/scripts/convert_consistency_decoder.py b/scripts/convert_consistency_decoder.py
index 629c784c09..9e28945775 100644
--- a/scripts/convert_consistency_decoder.py
+++ b/scripts/convert_consistency_decoder.py
@@ -24,7 +24,8 @@ args = args.parse_args()
def _extract_into_tensor(arr, timesteps, broadcast_shape):
- # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895 """
+ # from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/gaussian_diffusion.py#L895
+ # """
res = arr[timesteps].float()
dims_to_append = len(broadcast_shape) - len(res.shape)
return res[(...,) + (None,) * dims_to_append]
@@ -507,7 +508,9 @@ def rename_state_dict(sd, embedding):
# encode with stable diffusion vae
-pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+pipe = StableDiffusionPipeline.from_pretrained(
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
+)
pipe.vae.cuda()
# construct original decoder with jitted model
@@ -1090,7 +1093,7 @@ def new_constructor(self, **kwargs):
Encoder.__init__ = new_constructor
-vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
+vae = AutoencoderKL.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae")
consistency_vae = ConsistencyDecoderVAE(
encoder_args=vae.encoder.constructor_arguments,
decoder_args=unet.config,
@@ -1117,7 +1120,7 @@ print((sample_consistency_orig - sample_consistency_new_3).abs().sum())
print("running with diffusers pipeline")
pipe = DiffusionPipeline.from_pretrained(
- "runwayml/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
+ "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=consistency_vae, torch_dtype=torch.float16
)
pipe.to("cuda")
diff --git a/scripts/convert_hunyuan_image_to_diffusers.py b/scripts/convert_hunyuan_image_to_diffusers.py
new file mode 100644
index 0000000000..c41e934cc3
--- /dev/null
+++ b/scripts/convert_hunyuan_image_to_diffusers.py
@@ -0,0 +1,1044 @@
+import argparse
+import logging
+
+import torch
+from safetensors import safe_open
+
+from diffusers import AutoencoderKLHunyuanImage, AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
+
+
+logger = logging.getLogger(__name__) # pylint: disable=invalid-name
+
+
+"""
+Usage examples
+==============
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage2.1 \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1.safetensors" \
+ --vae_checkpoint_path "HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
+ --output_path "/raid/yiyi/test-hy21-diffusers" \
+ --dtype fp32
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage2.1-distilled \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage2.1-distilled.safetensors" \
+ --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_2_1/pytorch_model.ckpt" \
+ --output_path "/raid/yiyi/test-hy21-distilled-diffusers" \
+ --dtype fp32
+
+
+python scripts/convert_hunyuan_image_to_diffusers.py \
+ --model_type hunyuanimage-refiner \
+ --transformer_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/dit/hunyuanimage-refiner.safetensors" \
+ --vae_checkpoint_path "/raid/yiyi/HunyuanImage-2.1/ckpts/vae/vae_refiner/pytorch_model.pt" \
+ --output_path "/raid/yiyi/test-hy2-refiner-diffusers" \
+ --dtype fp32
+"""
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ "--model_type", type=str, default=None
+) # hunyuanimage2.1, hunyuanimage2.1-distilled, hunyuanimage-refiner
+parser.add_argument("--transformer_checkpoint_path", default=None, type=str) # ckpts/dit/hunyuanimage2.1.safetensors
+parser.add_argument("--vae_checkpoint_path", default=None, type=str) # ckpts/vae/vae_2_1/pytorch_model.ckpt
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--dtype", type=str, default="fp32")
+
+args = parser.parse_args()
+dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
+
+
+# copied from https://github.com/Tencent-Hunyuan/HunyuanImage-2.1/hyimage/models/hunyuan/modules/hunyuanimage_dit.py#L21
+def convert_hunyuan_dict_for_tensor_parallel(state_dict):
+ """
+ Convert a Hunyuan model state dict to be compatible with tensor parallel architectures.
+
+ Args:
+ state_dict: Original state dict
+
+ Returns:
+ new_dict: Converted state dict
+ """
+ new_dict = {}
+ for k, w in state_dict.items():
+ if k.startswith("double_blocks") and "attn_qkv.weight" in k:
+ hidden_size = w.shape[1]
+ k1 = k.replace("attn_qkv.weight", "attn_q.weight")
+ w1 = w[:hidden_size, :]
+ new_dict[k1] = w1
+ k2 = k.replace("attn_qkv.weight", "attn_k.weight")
+ w2 = w[hidden_size : 2 * hidden_size, :]
+ new_dict[k2] = w2
+ k3 = k.replace("attn_qkv.weight", "attn_v.weight")
+ w3 = w[-hidden_size:, :]
+ new_dict[k3] = w3
+ elif k.startswith("double_blocks") and "attn_qkv.bias" in k:
+ hidden_size = w.shape[0] // 3
+ k1 = k.replace("attn_qkv.bias", "attn_q.bias")
+ w1 = w[:hidden_size]
+ new_dict[k1] = w1
+ k2 = k.replace("attn_qkv.bias", "attn_k.bias")
+ w2 = w[hidden_size : 2 * hidden_size]
+ new_dict[k2] = w2
+ k3 = k.replace("attn_qkv.bias", "attn_v.bias")
+ w3 = w[-hidden_size:]
+ new_dict[k3] = w3
+ elif k.startswith("single_blocks") and "linear1" in k:
+ hidden_size = state_dict[k.replace("linear1", "linear2")].shape[0]
+ k1 = k.replace("linear1", "linear1_q")
+ w1 = w[:hidden_size]
+ new_dict[k1] = w1
+ k2 = k.replace("linear1", "linear1_k")
+ w2 = w[hidden_size : 2 * hidden_size]
+ new_dict[k2] = w2
+ k3 = k.replace("linear1", "linear1_v")
+ w3 = w[2 * hidden_size : 3 * hidden_size]
+ new_dict[k3] = w3
+ k4 = k.replace("linear1", "linear1_mlp")
+ w4 = w[3 * hidden_size :]
+ new_dict[k4] = w4
+ elif k.startswith("single_blocks") and "linear2" in k:
+ k1 = k.replace("linear2", "linear2.fc")
+ new_dict[k1] = w
+ else:
+ new_dict[k] = w
+ return new_dict
+
+
+def load_original_vae_checkpoint(args):
+ # "ckpts/vae/vae_2_1/pytorch_model.ckpt"
+ state_dict = torch.load(args.vae_checkpoint_path)
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ vae_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith("vae."):
+ vae_state_dict[k.replace("vae.", "")] = v
+
+ for k, v in vae_state_dict.items():
+ if "weight" in k:
+ if len(v.shape) == 5 and v.shape[2] == 1:
+ vae_state_dict[k] = v.squeeze(2)
+ else:
+ vae_state_dict[k] = v
+ else:
+ vae_state_dict[k] = v
+ return vae_state_dict
+
+
+def load_original_refiner_vae_checkpoint(args):
+ # "ckpts/vae/vae_refiner/pytorch_model.pt"
+ state_dict = torch.load(args.vae_checkpoint_path)
+
+ if "state_dict" in state_dict:
+ state_dict = state_dict["state_dict"]
+ vae_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith("vae."):
+ vae_state_dict[k.replace("vae.", "")] = v
+ return vae_state_dict
+
+
+def load_original_transformer_checkpoint(args):
+ # ckpts/dit/hunyuanimage-refiner.safetensors"
+ # ckpts/dit/hunyuanimage2.1.safetensors"
+ state_dict = {}
+ with safe_open(args.transformer_checkpoint_path, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ state_dict[key] = f.get_tensor(key)
+ if args.model_type == "hunyuanimage-2.1":
+ state_dict = convert_hunyuan_dict_for_tensor_parallel(state_dict)
+ return state_dict
+
+
+def convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_state_dict, use_byt5=True, guidance_distilled=False, use_meanflow=False
+):
+ converted_state_dict = {}
+
+ # 1. byt5_in -> context_embedder_2
+ if use_byt5:
+ converted_state_dict["context_embedder_2.norm.weight"] = original_state_dict.pop("byt5_in.layernorm.weight")
+ converted_state_dict["context_embedder_2.norm.bias"] = original_state_dict.pop("byt5_in.layernorm.bias")
+ converted_state_dict["context_embedder_2.linear_1.weight"] = original_state_dict.pop("byt5_in.fc1.weight")
+ converted_state_dict["context_embedder_2.linear_1.bias"] = original_state_dict.pop("byt5_in.fc1.bias")
+ converted_state_dict["context_embedder_2.linear_2.weight"] = original_state_dict.pop("byt5_in.fc2.weight")
+ converted_state_dict["context_embedder_2.linear_2.bias"] = original_state_dict.pop("byt5_in.fc2.bias")
+ converted_state_dict["context_embedder_2.linear_3.weight"] = original_state_dict.pop("byt5_in.fc3.weight")
+ converted_state_dict["context_embedder_2.linear_3.bias"] = original_state_dict.pop("byt5_in.fc3.bias")
+
+ # 2. img_in -> x_embedder
+ converted_state_dict["x_embedder.proj.weight"] = original_state_dict.pop("img_in.proj.weight")
+ converted_state_dict["x_embedder.proj.bias"] = original_state_dict.pop("img_in.proj.bias")
+
+ # 3. txt_in -> context_embedder (complex mapping)
+ # txt_in.input_embedder -> context_embedder.proj_in
+ converted_state_dict["context_embedder.proj_in.weight"] = original_state_dict.pop("txt_in.input_embedder.weight")
+ converted_state_dict["context_embedder.proj_in.bias"] = original_state_dict.pop("txt_in.input_embedder.bias")
+
+ # txt_in.t_embedder -> context_embedder.time_text_embed.timestep_embedder
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.0.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.0.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.weight"] = (
+ original_state_dict.pop("txt_in.t_embedder.mlp.2.weight")
+ )
+ converted_state_dict["context_embedder.time_text_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.t_embedder.mlp.2.bias"
+ )
+
+ # txt_in.c_embedder -> context_embedder.time_text_embed.text_embedder
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_1.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_1.bias"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.weight"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.weight"
+ )
+ converted_state_dict["context_embedder.time_text_embed.text_embedder.linear_2.bias"] = original_state_dict.pop(
+ "txt_in.c_embedder.linear_2.bias"
+ )
+
+ # txt_in.individual_token_refiner -> context_embedder.token_refiner
+ for i in range(2): # 2 refiner blocks
+ block_prefix = f"context_embedder.token_refiner.refiner_blocks.{i}."
+ # norm1
+ converted_state_dict[f"{block_prefix}norm1.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm1.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm1.bias"
+ )
+ # norm2
+ converted_state_dict[f"{block_prefix}norm2.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm2.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm2.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.norm2.bias"
+ )
+
+ # Split QKV
+ qkv_weight = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.weight")
+ qkv_bias = original_state_dict.pop(f"txt_in.individual_token_refiner.blocks.{i}.self_attn_qkv.bias")
+ q_weight, k_weight, v_weight = torch.chunk(qkv_weight, 3, dim=0)
+ q_bias, k_bias, v_bias = torch.chunk(qkv_bias, 3, dim=0)
+
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = q_weight
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = q_bias
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = k_weight
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = k_bias
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = v_weight
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = v_bias
+
+ # attn projection
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.self_attn_proj.bias"
+ )
+
+ # MLP
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.mlp.fc2.bias"
+ )
+
+ # norm_out
+ converted_state_dict[f"{block_prefix}norm_out.linear.weight"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm_out.linear.bias"] = original_state_dict.pop(
+ f"txt_in.individual_token_refiner.blocks.{i}.adaLN_modulation.1.bias"
+ )
+
+ # 4. time_in -> time_text_embed.timestep_embedder
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
+ "time_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
+ "time_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
+ "time_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
+ "time_in.mlp.2.bias"
+ )
+
+ # time_r_in -> time_guidance_embed.timestep_r_embedder
+ if use_meanflow:
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_1.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.weight"] = original_state_dict.pop(
+ "time_r_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.timestep_embedder_r.linear_2.bias"] = original_state_dict.pop(
+ "time_r_in.mlp.2.bias"
+ )
+
+ # guidance_in -> time_guidance_embed.guidance_embedder
+ if guidance_distilled:
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.weight"] = original_state_dict.pop(
+ "guidance_in.mlp.0.weight"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_1.bias"] = original_state_dict.pop(
+ "guidance_in.mlp.0.bias"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.weight"] = original_state_dict.pop(
+ "guidance_in.mlp.2.weight"
+ )
+ converted_state_dict["time_guidance_embed.guidance_embedder.linear_2.bias"] = original_state_dict.pop(
+ "guidance_in.mlp.2.bias"
+ )
+
+ # 5. double_blocks -> transformer_blocks
+ for i in range(20): # 20 double blocks
+ block_prefix = f"transformer_blocks.{i}."
+
+ # norm1 (img_mod)
+ converted_state_dict[f"{block_prefix}norm1.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mod.linear.bias"
+ )
+
+ # norm1_context (txt_mod)
+ converted_state_dict[f"{block_prefix}norm1_context.linear.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm1_context.linear.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mod.linear.bias"
+ )
+
+ # img attention
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_v.bias"
+ )
+
+ # img attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_k_norm.weight"
+ )
+
+ # img attention projection
+ converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_attn_proj.bias"
+ )
+
+ # img MLP
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.img_mlp.fc2.bias"
+ )
+
+ # txt attention (additional projections)
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_v.bias"
+ )
+
+ # txt attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_k_norm.weight"
+ )
+
+ # txt attention projection
+ converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_proj.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_attn_proj.bias"
+ )
+
+ # txt MLP (ff_context)
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc1.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc1.bias"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc2.weight"
+ )
+ converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = original_state_dict.pop(
+ f"double_blocks.{i}.txt_mlp.fc2.bias"
+ )
+
+ # 6. single_blocks -> single_transformer_blocks
+ for i in range(40): # 40 single blocks
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ # norm
+ converted_state_dict[f"{block_prefix}norm.linear.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.linear.weight"
+ )
+ converted_state_dict[f"{block_prefix}norm.linear.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.modulation.linear.bias"
+ )
+
+ # attention Q, K, V
+ converted_state_dict[f"{block_prefix}attn.to_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_q.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_q.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_k.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_k.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_k.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_v.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_v.bias"
+ )
+
+ # attention norms
+ converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.q_norm.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.k_norm.weight"
+ )
+
+ # MLP projection
+ converted_state_dict[f"{block_prefix}proj_mlp.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_mlp.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_mlp.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear1_mlp.bias"
+ )
+
+ # output projection
+ converted_state_dict[f"{block_prefix}proj_out.weight"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.fc.weight"
+ )
+ converted_state_dict[f"{block_prefix}proj_out.bias"] = original_state_dict.pop(
+ f"single_blocks.{i}.linear2.fc.bias"
+ )
+
+ # 7. final_layer -> norm_out + proj_out
+ converted_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
+ shift_w, scale_w = original_state_dict.pop("final_layer.adaLN_modulation.1.weight").chunk(2, dim=0)
+ shift_b, scale_b = original_state_dict.pop("final_layer.adaLN_modulation.1.bias").chunk(2, dim=0)
+ converted_state_dict["norm_out.linear.weight"] = torch.cat([scale_w, shift_w], dim=0)
+ converted_state_dict["norm_out.linear.bias"] = torch.cat([scale_b, shift_b], dim=0)
+
+ return converted_state_dict, original_state_dict
+
+
+def convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+):
+ """Convert original VAE state dict to Diffusers format."""
+ converted = {}
+
+ # 1. Encoder
+ # 1.1 conv_in
+ converted["encoder.conv_in.weight"] = original_state_dict.pop("encoder.conv_in.weight")
+ converted["encoder.conv_in.bias"] = original_state_dict.pop("encoder.conv_in.bias")
+
+ # 1.2 down blocks
+ diffusers_block_idx = 0
+
+ for block_index in range(len(block_out_channels)):
+ for resnet_block_index in range(layers_per_block):
+ orig_prefix = f"encoder.down.{block_index}.block.{resnet_block_index}"
+ diff_prefix = f"encoder.down_blocks.{diffusers_block_idx}"
+
+ # resnet blocks
+ converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
+ converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
+ converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
+ converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
+ converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
+ converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
+ converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
+ converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
+
+ diffusers_block_idx += 1
+
+ # downsample blocks
+ if f"encoder.down.{block_index}.downsample.conv.weight" in original_state_dict:
+ converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
+ f"encoder.down.{block_index}.downsample.conv.weight"
+ )
+ converted[f"encoder.down_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
+ f"encoder.down.{block_index}.downsample.conv.bias"
+ )
+ diffusers_block_idx += 1
+
+ # 1.3 mid block
+ converted["encoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("encoder.mid.block_1.norm1.weight")
+ converted["encoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("encoder.mid.block_1.norm1.bias")
+ converted["encoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("encoder.mid.block_1.conv1.weight")
+ converted["encoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("encoder.mid.block_1.conv1.bias")
+ converted["encoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("encoder.mid.block_1.norm2.weight")
+ converted["encoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("encoder.mid.block_1.norm2.bias")
+ converted["encoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("encoder.mid.block_1.conv2.weight")
+ converted["encoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("encoder.mid.block_1.conv2.bias")
+
+ converted["encoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("encoder.mid.block_2.norm1.weight")
+ converted["encoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("encoder.mid.block_2.norm1.bias")
+ converted["encoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("encoder.mid.block_2.conv1.weight")
+ converted["encoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("encoder.mid.block_2.conv1.bias")
+ converted["encoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("encoder.mid.block_2.norm2.weight")
+ converted["encoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("encoder.mid.block_2.norm2.bias")
+ converted["encoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("encoder.mid.block_2.conv2.weight")
+ converted["encoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("encoder.mid.block_2.conv2.bias")
+
+ converted["encoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("encoder.mid.attn_1.norm.weight")
+ converted["encoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("encoder.mid.attn_1.norm.bias")
+ converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
+ converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
+ converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
+ converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
+ converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
+ converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
+ converted["encoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.weight"
+ )
+ converted["encoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("encoder.mid.attn_1.proj_out.bias")
+
+ # 1.4 encoder output
+ converted["encoder.norm_out.weight"] = original_state_dict.pop("encoder.norm_out.weight")
+ converted["encoder.norm_out.bias"] = original_state_dict.pop("encoder.norm_out.bias")
+ converted["encoder.conv_out.weight"] = original_state_dict.pop("encoder.conv_out.weight")
+ converted["encoder.conv_out.bias"] = original_state_dict.pop("encoder.conv_out.bias")
+
+ # 2. Decoder
+ # 2.1 conv_in
+ converted["decoder.conv_in.weight"] = original_state_dict.pop("decoder.conv_in.weight")
+ converted["decoder.conv_in.bias"] = original_state_dict.pop("decoder.conv_in.bias")
+
+ # 2.2 mid block
+ converted["decoder.mid_block.resnets.0.norm1.weight"] = original_state_dict.pop("decoder.mid.block_1.norm1.weight")
+ converted["decoder.mid_block.resnets.0.norm1.bias"] = original_state_dict.pop("decoder.mid.block_1.norm1.bias")
+ converted["decoder.mid_block.resnets.0.conv1.weight"] = original_state_dict.pop("decoder.mid.block_1.conv1.weight")
+ converted["decoder.mid_block.resnets.0.conv1.bias"] = original_state_dict.pop("decoder.mid.block_1.conv1.bias")
+ converted["decoder.mid_block.resnets.0.norm2.weight"] = original_state_dict.pop("decoder.mid.block_1.norm2.weight")
+ converted["decoder.mid_block.resnets.0.norm2.bias"] = original_state_dict.pop("decoder.mid.block_1.norm2.bias")
+ converted["decoder.mid_block.resnets.0.conv2.weight"] = original_state_dict.pop("decoder.mid.block_1.conv2.weight")
+ converted["decoder.mid_block.resnets.0.conv2.bias"] = original_state_dict.pop("decoder.mid.block_1.conv2.bias")
+
+ converted["decoder.mid_block.resnets.1.norm1.weight"] = original_state_dict.pop("decoder.mid.block_2.norm1.weight")
+ converted["decoder.mid_block.resnets.1.norm1.bias"] = original_state_dict.pop("decoder.mid.block_2.norm1.bias")
+ converted["decoder.mid_block.resnets.1.conv1.weight"] = original_state_dict.pop("decoder.mid.block_2.conv1.weight")
+ converted["decoder.mid_block.resnets.1.conv1.bias"] = original_state_dict.pop("decoder.mid.block_2.conv1.bias")
+ converted["decoder.mid_block.resnets.1.norm2.weight"] = original_state_dict.pop("decoder.mid.block_2.norm2.weight")
+ converted["decoder.mid_block.resnets.1.norm2.bias"] = original_state_dict.pop("decoder.mid.block_2.norm2.bias")
+ converted["decoder.mid_block.resnets.1.conv2.weight"] = original_state_dict.pop("decoder.mid.block_2.conv2.weight")
+ converted["decoder.mid_block.resnets.1.conv2.bias"] = original_state_dict.pop("decoder.mid.block_2.conv2.bias")
+
+ converted["decoder.mid_block.attentions.0.norm.weight"] = original_state_dict.pop("decoder.mid.attn_1.norm.weight")
+ converted["decoder.mid_block.attentions.0.norm.bias"] = original_state_dict.pop("decoder.mid.attn_1.norm.bias")
+ converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
+ converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
+ converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
+ converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
+ converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
+ converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
+ converted["decoder.mid_block.attentions.0.proj.weight"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.weight"
+ )
+ converted["decoder.mid_block.attentions.0.proj.bias"] = original_state_dict.pop("decoder.mid.attn_1.proj_out.bias")
+
+ # 2.3 up blocks
+ diffusers_block_idx = 0
+ for up_block_index in range(len(block_out_channels)):
+ # resnet blocks
+ for resnet_block_index in range(layers_per_block + 1):
+ orig_prefix = f"decoder.up.{up_block_index}.block.{resnet_block_index}"
+ diff_prefix = f"decoder.up_blocks.{diffusers_block_idx}"
+
+ converted[f"{diff_prefix}.norm1.weight"] = original_state_dict.pop(f"{orig_prefix}.norm1.weight")
+ converted[f"{diff_prefix}.norm1.bias"] = original_state_dict.pop(f"{orig_prefix}.norm1.bias")
+ converted[f"{diff_prefix}.conv1.weight"] = original_state_dict.pop(f"{orig_prefix}.conv1.weight")
+ converted[f"{diff_prefix}.conv1.bias"] = original_state_dict.pop(f"{orig_prefix}.conv1.bias")
+ converted[f"{diff_prefix}.norm2.weight"] = original_state_dict.pop(f"{orig_prefix}.norm2.weight")
+ converted[f"{diff_prefix}.norm2.bias"] = original_state_dict.pop(f"{orig_prefix}.norm2.bias")
+ converted[f"{diff_prefix}.conv2.weight"] = original_state_dict.pop(f"{orig_prefix}.conv2.weight")
+ converted[f"{diff_prefix}.conv2.bias"] = original_state_dict.pop(f"{orig_prefix}.conv2.bias")
+
+ diffusers_block_idx += 1
+
+ # upsample blocks
+ if f"decoder.up.{up_block_index}.upsample.conv.weight" in original_state_dict:
+ converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.weight"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.weight"
+ )
+ converted[f"decoder.up_blocks.{diffusers_block_idx}.conv.bias"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.bias"
+ )
+ diffusers_block_idx += 1
+
+ # 2.4 decoder output
+ converted["decoder.norm_out.weight"] = original_state_dict.pop("decoder.norm_out.weight")
+ converted["decoder.norm_out.bias"] = original_state_dict.pop("decoder.norm_out.bias")
+ converted["decoder.conv_out.weight"] = original_state_dict.pop("decoder.conv_out.weight")
+ converted["decoder.conv_out.bias"] = original_state_dict.pop("decoder.conv_out.bias")
+
+ return converted, original_state_dict
+
+
+def convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
+ original_state_dict, block_out_channels=[128, 256, 512, 1024, 1024], layers_per_block=2
+):
+ converted = {}
+
+ # 1. Encoder
+ # 1.1 conv_in
+ converted["encoder.conv_in.conv.weight"] = original_state_dict.pop("encoder.conv_in.conv.weight")
+ converted["encoder.conv_in.conv.bias"] = original_state_dict.pop("encoder.conv_in.conv.bias")
+
+ # 1.2 Down blocks
+ for down_block_index in range(len(block_out_channels)): # 0 to 4
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block): # 0 to 1
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(
+ f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.weight"
+ )
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Downsample (if exists)
+ if f"encoder.down.{down_block_index}.downsample.conv.conv.weight" in original_state_dict:
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.weight"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.weight")
+ )
+ converted[f"encoder.down_blocks.{down_block_index}.downsamplers.0.conv.conv.bias"] = (
+ original_state_dict.pop(f"encoder.down.{down_block_index}.downsample.conv.conv.bias")
+ )
+
+ # 1.3 Mid block
+ converted["encoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm1.gamma")
+ converted["encoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_1.norm2.gamma")
+ converted["encoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["encoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm1.gamma")
+ converted["encoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["encoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("encoder.mid.block_2.norm2.gamma")
+ converted["encoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["encoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "encoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Attention block
+ converted["encoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("encoder.mid.attn_1.norm.gamma")
+ converted["encoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("encoder.mid.attn_1.q.weight")
+ converted["encoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("encoder.mid.attn_1.q.bias")
+ converted["encoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("encoder.mid.attn_1.k.weight")
+ converted["encoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("encoder.mid.attn_1.k.bias")
+ converted["encoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("encoder.mid.attn_1.v.weight")
+ converted["encoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("encoder.mid.attn_1.v.bias")
+ converted["encoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.weight"
+ )
+ converted["encoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "encoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 1.4 Encoder output
+ converted["encoder.norm_out.gamma"] = original_state_dict.pop("encoder.norm_out.gamma")
+ converted["encoder.conv_out.conv.weight"] = original_state_dict.pop("encoder.conv_out.conv.weight")
+ converted["encoder.conv_out.conv.bias"] = original_state_dict.pop("encoder.conv_out.conv.bias")
+
+ # 2. Decoder
+ # 2.1 conv_in
+ converted["decoder.conv_in.conv.weight"] = original_state_dict.pop("decoder.conv_in.conv.weight")
+ converted["decoder.conv_in.conv.bias"] = original_state_dict.pop("decoder.conv_in.conv.bias")
+
+ # 2.2 Mid block
+ converted["decoder.mid_block.resnets.0.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm1.gamma")
+ converted["decoder.mid_block.resnets.0.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.0.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_1.norm2.gamma")
+ converted["decoder.mid_block.resnets.0.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.0.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_1.conv2.conv.bias"
+ )
+
+ converted["decoder.mid_block.resnets.1.norm1.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm1.gamma")
+ converted["decoder.mid_block.resnets.1.conv1.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv1.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv1.conv.bias"
+ )
+ converted["decoder.mid_block.resnets.1.norm2.gamma"] = original_state_dict.pop("decoder.mid.block_2.norm2.gamma")
+ converted["decoder.mid_block.resnets.1.conv2.conv.weight"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.weight"
+ )
+ converted["decoder.mid_block.resnets.1.conv2.conv.bias"] = original_state_dict.pop(
+ "decoder.mid.block_2.conv2.conv.bias"
+ )
+
+ # Decoder attention block
+ converted["decoder.mid_block.attentions.0.norm.gamma"] = original_state_dict.pop("decoder.mid.attn_1.norm.gamma")
+ converted["decoder.mid_block.attentions.0.to_q.weight"] = original_state_dict.pop("decoder.mid.attn_1.q.weight")
+ converted["decoder.mid_block.attentions.0.to_q.bias"] = original_state_dict.pop("decoder.mid.attn_1.q.bias")
+ converted["decoder.mid_block.attentions.0.to_k.weight"] = original_state_dict.pop("decoder.mid.attn_1.k.weight")
+ converted["decoder.mid_block.attentions.0.to_k.bias"] = original_state_dict.pop("decoder.mid.attn_1.k.bias")
+ converted["decoder.mid_block.attentions.0.to_v.weight"] = original_state_dict.pop("decoder.mid.attn_1.v.weight")
+ converted["decoder.mid_block.attentions.0.to_v.bias"] = original_state_dict.pop("decoder.mid.attn_1.v.bias")
+ converted["decoder.mid_block.attentions.0.proj_out.weight"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.weight"
+ )
+ converted["decoder.mid_block.attentions.0.proj_out.bias"] = original_state_dict.pop(
+ "decoder.mid.attn_1.proj_out.bias"
+ )
+
+ # 2.3 Up blocks
+ for up_block_index in range(len(block_out_channels)): # 0 to 5
+ # ResNet blocks
+ for resnet_block_index in range(layers_per_block + 1): # 0 to 2 (decoder has 3 resnets per level)
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm1.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm1.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv1.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv1.conv.bias")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.norm2.gamma"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.norm2.gamma")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.weight"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.weight")
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.resnets.{resnet_block_index}.conv2.conv.bias"] = (
+ original_state_dict.pop(f"decoder.up.{up_block_index}.block.{resnet_block_index}.conv2.conv.bias")
+ )
+
+ # Upsample (if exists)
+ if f"decoder.up.{up_block_index}.upsample.conv.conv.weight" in original_state_dict:
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.weight"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.weight"
+ )
+ converted[f"decoder.up_blocks.{up_block_index}.upsamplers.0.conv.conv.bias"] = original_state_dict.pop(
+ f"decoder.up.{up_block_index}.upsample.conv.conv.bias"
+ )
+
+ # 2.4 Decoder output
+ converted["decoder.norm_out.gamma"] = original_state_dict.pop("decoder.norm_out.gamma")
+ converted["decoder.conv_out.conv.weight"] = original_state_dict.pop("decoder.conv_out.conv.weight")
+ converted["decoder.conv_out.conv.bias"] = original_state_dict.pop("decoder.conv_out.conv.bias")
+
+ return converted, original_state_dict
+
+
+def main(args):
+ if args.model_type == "hunyuanimage2.1":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 64,
+ "out_channels": 64,
+ "num_attention_heads": 28,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "patch_size": (1, 1),
+ "qk_norm": "rms_norm",
+ "guidance_embeds": False,
+ "text_embed_dim": 3584,
+ "text_embed_2_dim": 1472,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (64, 64),
+ }
+
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=True, guidance_distilled=False
+ )
+ )
+
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae_config_diffusers = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 64,
+ "block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "layers_per_block": 2,
+ "spatial_compression_ratio": 32,
+ "sample_size": 384,
+ "scaling_factor": 0.75289,
+ "downsample_match_channel": True,
+ "upsample_match_channel": True,
+ }
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+ elif args.model_type == "hunyuanimage2.1-distilled":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 64,
+ "out_channels": 64,
+ "num_attention_heads": 28,
+ "attention_head_dim": 128,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "num_refiner_layers": 2,
+ "patch_size": (1, 1),
+ "qk_norm": "rms_norm",
+ "guidance_embeds": True,
+ "text_embed_dim": 3584,
+ "text_embed_2_dim": 1472,
+ "rope_theta": 256.0,
+ "rope_axes_dim": (64, 64),
+ "use_meanflow": True,
+ }
+
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=True, guidance_distilled=True, use_meanflow=True
+ )
+ )
+
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae_config_diffusers = {
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 64,
+ "block_out_channels": [128, 256, 512, 512, 1024, 1024],
+ "layers_per_block": 2,
+ "spatial_compression_ratio": 32,
+ "sample_size": 384,
+ "scaling_factor": 0.75289,
+ "downsample_match_channel": True,
+ "upsample_match_channel": True,
+ }
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_vae_checkpoint_to_diffusers(
+ original_vae_state_dict, block_out_channels=[128, 256, 512, 512, 1024, 1024], layers_per_block=2
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ vae = AutoencoderKLHunyuanImage(**vae_config_diffusers)
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+
+ if missing_keys:
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+ elif args.model_type == "hunyuanimage-refiner":
+ original_transformer_state_dict = load_original_transformer_checkpoint(args)
+ original_vae_state_dict = load_original_refiner_vae_checkpoint(args)
+
+ transformer_config = {
+ "in_channels": 128,
+ "out_channels": 64,
+ "num_layers": 20,
+ "num_single_layers": 40,
+ "rope_axes_dim": [16, 56, 56],
+ "num_attention_heads": 26,
+ "attention_head_dim": 128,
+ "mlp_ratio": 4,
+ "patch_size": (1, 1, 1),
+ "text_embed_dim": 3584,
+ "guidance_embeds": True,
+ }
+ converted_transformer_state_dict, original_transformer_state_dict = (
+ convert_hunyuan_image_transformer_checkpoint_to_diffusers(
+ original_transformer_state_dict, use_byt5=False, guidance_distilled=True
+ )
+ )
+ if original_transformer_state_dict:
+ logger.warning(
+ f"Unused {len(original_transformer_state_dict)} original keys for transformer: {list(original_transformer_state_dict.keys())}"
+ )
+
+ transformer = HunyuanImageTransformer2DModel(**transformer_config)
+ missing_keys, unexpected_key = transformer.load_state_dict(converted_transformer_state_dict, strict=True)
+ if missing_keys:
+ logger.warning(f"Missing keys for transformer: {missing_keys}")
+ if unexpected_key:
+ logger.warning(f"Unexpected keys for transformer: {unexpected_key}")
+
+ transformer.to(dtype).save_pretrained(f"{args.output_path}/transformer")
+
+ vae = AutoencoderKLHunyuanImageRefiner()
+
+ converted_vae_state_dict, original_vae_state_dict = convert_hunyuan_image_refiner_vae_checkpoint_to_diffusers(
+ original_vae_state_dict
+ )
+ if original_vae_state_dict:
+ logger.warning(
+ f"Unused {len(original_vae_state_dict)} original keys for vae: {list(original_vae_state_dict.keys())}"
+ )
+
+ missing_keys, unexpected_key = vae.load_state_dict(converted_vae_state_dict, strict=True)
+ logger.warning(f"Missing keys for vae: {missing_keys}")
+ logger.warning(f"Unexpected keys for vae: {unexpected_key}")
+
+ vae.to(dtype).save_pretrained(f"{args.output_path}/vae")
+
+
+if __name__ == "__main__":
+ main(args)
diff --git a/scripts/convert_ltx_to_diffusers.py b/scripts/convert_ltx_to_diffusers.py
index 256312cc72..19e5602039 100644
--- a/scripts/convert_ltx_to_diffusers.py
+++ b/scripts/convert_ltx_to_diffusers.py
@@ -369,6 +369,15 @@ def get_spatial_latent_upsampler_config(version: str) -> Dict[str, Any]:
"spatial_upsample": True,
"temporal_upsample": False,
}
+ elif version == "0.9.8":
+ config = {
+ "in_channels": 128,
+ "mid_channels": 512,
+ "num_blocks_per_stage": 4,
+ "dims": 3,
+ "spatial_upsample": True,
+ "temporal_upsample": False,
+ }
else:
raise ValueError(f"Unsupported version: {version}")
return config
@@ -402,7 +411,7 @@ def get_args():
"--version",
type=str,
default="0.9.0",
- choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7"],
+ choices=["0.9.0", "0.9.1", "0.9.5", "0.9.7", "0.9.8"],
help="Version of the LTX model",
)
return parser.parse_args()
diff --git a/scripts/convert_prx_to_diffusers.py b/scripts/convert_prx_to_diffusers.py
new file mode 100644
index 0000000000..d9bde2f34d
--- /dev/null
+++ b/scripts/convert_prx_to_diffusers.py
@@ -0,0 +1,345 @@
+#!/usr/bin/env python3
+"""
+Script to convert PRX checkpoint from original codebase to diffusers format.
+"""
+
+import argparse
+import json
+import os
+import sys
+from dataclasses import asdict, dataclass
+from typing import Dict, Tuple
+
+import torch
+from safetensors.torch import save_file
+
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx import PRXPipeline
+
+
+DEFAULT_RESOLUTION = 512
+
+
+@dataclass(frozen=True)
+class PRXBase:
+ context_in_dim: int = 2304
+ hidden_size: int = 1792
+ mlp_ratio: float = 3.5
+ num_heads: int = 28
+ depth: int = 16
+ axes_dim: Tuple[int, int] = (32, 32)
+ theta: int = 10_000
+ time_factor: float = 1000.0
+ time_max_period: int = 10_000
+
+
+@dataclass(frozen=True)
+class PRXFlux(PRXBase):
+ in_channels: int = 16
+ patch_size: int = 2
+
+
+@dataclass(frozen=True)
+class PRXDCAE(PRXBase):
+ in_channels: int = 32
+ patch_size: int = 1
+
+
+def build_config(vae_type: str) -> Tuple[dict, int]:
+ if vae_type == "flux":
+ cfg = PRXFlux()
+ elif vae_type == "dc-ae":
+ cfg = PRXDCAE()
+ else:
+ raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
+
+ config_dict = asdict(cfg)
+ config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
+ return config_dict
+
+
+def create_parameter_mapping(depth: int) -> dict:
+ """Create mapping from old parameter names to new diffusers names."""
+
+ # Key mappings for structural changes
+ mapping = {}
+
+ # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
+ for i in range(depth):
+ # QKV projections moved to attention module
+ mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
+ mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
+
+ # QK norm moved to attention module and renamed to match Attention's qk_norm structure
+ mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
+ mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
+ mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
+ mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
+
+ # K norm for text tokens moved to attention module
+ mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
+ mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
+
+ # Attention output projection
+ mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
+
+ return mapping
+
+
+def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
+ """Convert old checkpoint parameters to new diffusers format."""
+
+ print("Converting checkpoint parameters...")
+
+ mapping = create_parameter_mapping(depth)
+ converted_state_dict = {}
+
+ for key, value in old_state_dict.items():
+ new_key = key
+
+ # Apply specific mappings if needed
+ if key in mapping:
+ new_key = mapping[key]
+ print(f" Mapped: {key} -> {new_key}")
+
+ converted_state_dict[new_key] = value
+
+ print(f"✓ Converted {len(converted_state_dict)} parameters")
+ return converted_state_dict
+
+
+def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
+ """Create and load PRXTransformer2DModel from old checkpoint."""
+
+ print(f"Loading checkpoint from: {checkpoint_path}")
+
+ # Load old checkpoint
+ if not os.path.exists(checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
+
+ old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+ # Handle different checkpoint formats
+ if isinstance(old_checkpoint, dict):
+ if "model" in old_checkpoint:
+ state_dict = old_checkpoint["model"]
+ elif "state_dict" in old_checkpoint:
+ state_dict = old_checkpoint["state_dict"]
+ else:
+ state_dict = old_checkpoint
+ else:
+ state_dict = old_checkpoint
+
+ print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
+
+ # Convert parameter names if needed
+ model_depth = int(config.get("depth", 16))
+ converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
+
+ # Create transformer with config
+ print("Creating PRXTransformer2DModel...")
+ transformer = PRXTransformer2DModel(**config)
+
+ # Load state dict
+ print("Loading converted parameters...")
+ missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
+
+ if missing_keys:
+ print(f"⚠ Missing keys: {missing_keys}")
+ if unexpected_keys:
+ print(f"⚠ Unexpected keys: {unexpected_keys}")
+
+ if not missing_keys and not unexpected_keys:
+ print("✓ All parameters loaded successfully!")
+
+ return transformer
+
+
+def create_scheduler_config(output_path: str, shift: float):
+ """Create FlowMatchEulerDiscreteScheduler config."""
+
+ scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
+
+ scheduler_path = os.path.join(output_path, "scheduler")
+ os.makedirs(scheduler_path, exist_ok=True)
+
+ with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
+ json.dump(scheduler_config, f, indent=2)
+
+ print("✓ Created scheduler config")
+
+
+def download_and_save_vae(vae_type: str, output_path: str):
+ """Download and save VAE to local directory."""
+ from diffusers import AutoencoderDC, AutoencoderKL
+
+ vae_path = os.path.join(output_path, "vae")
+ os.makedirs(vae_path, exist_ok=True)
+
+ if vae_type == "flux":
+ print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
+ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
+ else: # dc-ae
+ print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
+ vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
+
+ vae.save_pretrained(vae_path)
+ print(f"✓ Saved VAE to {vae_path}")
+
+
+def download_and_save_text_encoder(output_path: str):
+ """Download and save T5Gemma text encoder and tokenizer."""
+ from transformers import GemmaTokenizerFast
+ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
+
+ text_encoder_path = os.path.join(output_path, "text_encoder")
+ tokenizer_path = os.path.join(output_path, "tokenizer")
+ os.makedirs(text_encoder_path, exist_ok=True)
+ os.makedirs(tokenizer_path, exist_ok=True)
+
+ print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
+ t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
+
+ # Extract and save only the encoder
+ t5gemma_encoder = t5gemma_model.encoder
+ t5gemma_encoder.save_pretrained(text_encoder_path)
+ print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
+
+ print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
+ tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
+ tokenizer.model_max_length = 256
+ tokenizer.save_pretrained(tokenizer_path)
+ print(f"✓ Saved tokenizer to {tokenizer_path}")
+
+
+def create_model_index(vae_type: str, default_image_size: int, output_path: str):
+ """Create model_index.json for the pipeline."""
+
+ if vae_type == "flux":
+ vae_class = "AutoencoderKL"
+ else: # dc-ae
+ vae_class = "AutoencoderDC"
+
+ model_index = {
+ "_class_name": "PRXPipeline",
+ "_diffusers_version": "0.31.0.dev0",
+ "_name_or_path": os.path.basename(output_path),
+ "default_sample_size": default_image_size,
+ "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
+ "text_encoder": ["prx", "T5GemmaEncoder"],
+ "tokenizer": ["transformers", "GemmaTokenizerFast"],
+ "transformer": ["diffusers", "PRXTransformer2DModel"],
+ "vae": ["diffusers", vae_class],
+ }
+
+ model_index_path = os.path.join(output_path, "model_index.json")
+ with open(model_index_path, "w") as f:
+ json.dump(model_index, f, indent=2)
+
+
+def main(args):
+ # Validate inputs
+ if not os.path.exists(args.checkpoint_path):
+ raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
+
+ config = build_config(args.vae_type)
+
+ # Create output directory
+ os.makedirs(args.output_path, exist_ok=True)
+ print(f"✓ Output directory: {args.output_path}")
+
+ # Create transformer from checkpoint
+ transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
+
+ # Save transformer
+ transformer_path = os.path.join(args.output_path, "transformer")
+ os.makedirs(transformer_path, exist_ok=True)
+
+ # Save config
+ with open(os.path.join(transformer_path, "config.json"), "w") as f:
+ json.dump(config, f, indent=2)
+
+ # Save model weights as safetensors
+ state_dict = transformer.state_dict()
+ save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
+ print(f"✓ Saved transformer to {transformer_path}")
+
+ # Create scheduler config
+ create_scheduler_config(args.output_path, args.shift)
+
+ download_and_save_vae(args.vae_type, args.output_path)
+ download_and_save_text_encoder(args.output_path)
+
+ # Create model_index.json
+ create_model_index(args.vae_type, args.resolution, args.output_path)
+
+ # Verify the pipeline can be loaded
+ try:
+ pipeline = PRXPipeline.from_pretrained(args.output_path)
+ print("Pipeline loaded successfully!")
+ print(f"Transformer: {type(pipeline.transformer).__name__}")
+ print(f"VAE: {type(pipeline.vae).__name__}")
+ print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
+ print(f"Scheduler: {type(pipeline.scheduler).__name__}")
+
+ # Display model info
+ num_params = sum(p.numel() for p in pipeline.transformer.parameters())
+ print(f"✓ Transformer parameters: {num_params:,}")
+
+ except Exception as e:
+ print(f"Pipeline verification failed: {e}")
+ return False
+
+ print("Conversion completed successfully!")
+ print(f"Converted pipeline saved to: {args.output_path}")
+ print(f"VAE type: {args.vae_type}")
+
+ return True
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
+
+ parser.add_argument(
+ "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
+ )
+
+ parser.add_argument(
+ "--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
+ )
+
+ parser.add_argument(
+ "--vae_type",
+ type=str,
+ choices=["flux", "dc-ae"],
+ required=True,
+ help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
+ )
+
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ choices=[256, 512, 1024],
+ default=DEFAULT_RESOLUTION,
+ help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
+ )
+
+ parser.add_argument(
+ "--shift",
+ type=float,
+ default=3.0,
+ help="Shift for the scheduler",
+ )
+
+ args = parser.parse_args()
+
+ try:
+ success = main(args)
+ if not success:
+ sys.exit(1)
+ except Exception as e:
+ print(f"Conversion failed: {e}")
+ import traceback
+
+ traceback.print_exc()
+ sys.exit(1)
diff --git a/scripts/convert_sana_video_to_diffusers.py b/scripts/convert_sana_video_to_diffusers.py
new file mode 100644
index 0000000000..fbb7c1d9e7
--- /dev/null
+++ b/scripts/convert_sana_video_to_diffusers.py
@@ -0,0 +1,324 @@
+#!/usr/bin/env python
+from __future__ import annotations
+
+import argparse
+import os
+from contextlib import nullcontext
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download, snapshot_download
+from termcolor import colored
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from diffusers import (
+ AutoencoderKLWan,
+ DPMSolverMultistepScheduler,
+ FlowMatchEulerDiscreteScheduler,
+ SanaVideoPipeline,
+ SanaVideoTransformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.import_utils import is_accelerate_available
+
+
+CTX = init_empty_weights if is_accelerate_available else nullcontext
+
+ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
+# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
+
+
+def main(args):
+ cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
+
+ if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids:
+ ckpt_id = args.orig_ckpt_path or ckpt_ids[0]
+ snapshot_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ file_path = hf_hub_download(
+ repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}",
+ filename=f"{'/'.join(ckpt_id.split('/')[2:])}",
+ cache_dir=cache_dir_path,
+ repo_type="model",
+ )
+ else:
+ file_path = args.orig_ckpt_path
+
+ print(colored(f"Loading checkpoint from {file_path}", "green", attrs=["bold"]))
+ all_state_dict = torch.load(file_path, weights_only=True)
+ state_dict = all_state_dict.pop("state_dict")
+ converted_state_dict = {}
+
+ # Patch embeddings.
+ converted_state_dict["patch_embedding.weight"] = state_dict.pop("x_embedder.proj.weight")
+ converted_state_dict["patch_embedding.bias"] = state_dict.pop("x_embedder.proj.bias")
+
+ # Caption projection.
+ converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight")
+ converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias")
+ converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight")
+ converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias")
+
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop(
+ "t_embedder.mlp.0.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias")
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop(
+ "t_embedder.mlp.2.weight"
+ )
+ converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias")
+
+ # Shared norm.
+ converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight")
+ converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias")
+
+ # y norm
+ converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
+
+ # scheduler
+ flow_shift = 8.0
+
+ # model config
+ layer_num = 20
+ # Positional embedding interpolation scale.
+ qk_norm = True
+
+ # sample size
+ if args.video_size == 480:
+ sample_size = 30 # Wan-VAE: 8xp2 downsample factor
+ patch_size = (1, 2, 2)
+ elif args.video_size == 720:
+ sample_size = 22 # Wan-VAE: 32xp1 downsample factor
+ patch_size = (1, 1, 1)
+ else:
+ raise ValueError(f"Video size {args.video_size} is not supported.")
+
+ for depth in range(layer_num):
+ # Transformer blocks.
+ converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
+ f"blocks.{depth}.scale_shift_table"
+ )
+
+ # Linear Attention is all you need 🤘
+ # Self attention.
+ q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
+ if qk_norm is not None:
+ # Add Q/K normalization for self-attention (attn1) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.k_norm.weight"
+ )
+ # Projection.
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.attn.proj.bias"
+ )
+
+ # Feed-forward.
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.inverted_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop(
+ f"blocks.{depth}.mlp.depth_conv.conv.bias"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.point_conv.conv.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.ff.conv_temp.weight"] = state_dict.pop(
+ f"blocks.{depth}.mlp.t_conv.weight"
+ )
+
+ # Cross-attention.
+ q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight")
+ q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias")
+ k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0)
+ k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0)
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias
+ if qk_norm is not None:
+ # Add Q/K normalization for cross-attention (attn2) - needed for Sana-Sprint and Sana-1.5
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_q.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.q_norm.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.norm_k.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.k_norm.weight"
+ )
+
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.weight"
+ )
+ converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop(
+ f"blocks.{depth}.cross_attn.proj.bias"
+ )
+
+ # Final block.
+ converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight")
+ converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias")
+ converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table")
+
+ # Transformer
+ with CTX():
+ transformer_kwargs = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 20,
+ "attention_head_dim": 112,
+ "num_layers": 20,
+ "num_cross_attention_heads": 20,
+ "cross_attention_head_dim": 112,
+ "cross_attention_dim": 2240,
+ "caption_channels": 2304,
+ "mlp_ratio": 3.0,
+ "attention_bias": False,
+ "sample_size": sample_size,
+ "patch_size": patch_size,
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 1024,
+ }
+
+ transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
+
+ transformer.load_state_dict(converted_state_dict, strict=True, assign=True)
+
+ try:
+ state_dict.pop("y_embedder.y_embedding")
+ state_dict.pop("pos_embed")
+ state_dict.pop("logvar_linear.weight")
+ state_dict.pop("logvar_linear.bias")
+ except KeyError:
+ print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
+ assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"
+
+ num_model_params = sum(p.numel() for p in transformer.parameters())
+ print(f"Total number of transformer parameters: {num_model_params}")
+
+ transformer = transformer.to(weight_dtype)
+
+ if not args.save_full_pipeline:
+ print(
+ colored(
+ f"Only saving transformer model of {args.model_type}. "
+ f"Set --save_full_pipeline to save the whole Pipeline",
+ "green",
+ attrs=["bold"],
+ )
+ )
+ transformer.save_pretrained(
+ os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB"
+ )
+ else:
+ print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
+ # VAE
+ vae = AutoencoderKLWan.from_pretrained(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+
+ # Text Encoder
+ text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
+ tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path)
+ tokenizer.padding_side = "right"
+ text_encoder = AutoModelForCausalLM.from_pretrained(
+ text_encoder_model_path, torch_dtype=torch.bfloat16
+ ).get_decoder()
+
+ # Choose the appropriate pipeline and scheduler based on model type
+ # Original Sana scheduler
+ if args.scheduler_type == "flow-dpm_solver":
+ scheduler = DPMSolverMultistepScheduler(
+ flow_shift=flow_shift,
+ use_flow_sigmas=True,
+ prediction_type="flow_prediction",
+ )
+ elif args.scheduler_type == "flow-euler":
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
+ elif args.scheduler_type == "uni-pc":
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction",
+ use_flow_sigmas=True,
+ num_train_timesteps=1000,
+ flow_shift=flow_shift,
+ )
+ else:
+ raise ValueError(f"Scheduler type {args.scheduler_type} is not supported")
+
+ pipe = SanaVideoPipeline(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ transformer=transformer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB")
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert."
+ )
+ parser.add_argument(
+ "--video_size",
+ default=480,
+ type=int,
+ choices=[480, 720],
+ required=False,
+ help="Video size of pretrained model, 480 or 720.",
+ )
+ parser.add_argument(
+ "--model_type",
+ default="SanaVideo",
+ type=str,
+ choices=[
+ "SanaVideo",
+ ],
+ )
+ parser.add_argument(
+ "--scheduler_type",
+ default="flow-dpm_solver",
+ type=str,
+ choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
+ help="Scheduler type to use.",
+ )
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
+ parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
+
+ args = parser.parse_args()
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ weight_dtype = DTYPE_MAPPING[args.dtype]
+
+ main(args)
diff --git a/setup.py b/setup.py
index ba3ad8e2b3..8d346ddfec 100644
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,8 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.34.0",
+ "httpx<1.0.0",
+ "huggingface-hub>=0.34.0,<2.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -144,6 +145,7 @@ _deps = [
"black",
"phonemizer",
"opencv-python",
+ "timm",
]
# this is a lookup table with items like:
@@ -217,7 +219,7 @@ class DepsTableUpdateCommand(Command):
extras = {}
extras["quality"] = deps_list("urllib3", "isort", "ruff", "hf-doc-builder")
extras["docs"] = deps_list("hf-doc-builder")
-extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft")
+extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
"compel",
"GitPython",
@@ -259,6 +261,7 @@ extras["dev"] = (
install_requires = [
deps["importlib_metadata"],
deps["filelock"],
+ deps["httpx"],
deps["huggingface-hub"],
deps["numpy"],
deps["regex"],
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index d96acc3818..572aad4bd3 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -149,7 +149,9 @@ else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
+ "AdaptiveProjectedMixGuidance",
"AutoGuidance",
+ "BaseGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"FrequencyDecoupledGuidance",
@@ -184,6 +186,8 @@ else:
"AutoencoderKLAllegro",
"AutoencoderKLCogVideoX",
"AutoencoderKLCosmos",
+ "AutoencoderKLHunyuanImage",
+ "AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
@@ -194,6 +198,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"AutoModel",
+ "BriaFiboTransformer2DModel",
"BriaTransformer2DModel",
"CacheMixin",
"ChromaTransformer2DModel",
@@ -202,6 +207,7 @@ else:
"CogView4Transformer2DModel",
"ConsisIDTransformer3DModel",
"ConsistencyDecoderVAE",
+ "ContextParallelConfig",
"ControlNetModel",
"ControlNetUnionModel",
"ControlNetXSAdapter",
@@ -215,10 +221,12 @@ else:
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
"HunyuanDiT2DMultiControlNetModel",
+ "HunyuanImageTransformer2DModel",
"HunyuanVideoFramepackTransformer3DModel",
"HunyuanVideoTransformer3DModel",
"I2VGenXLUNet",
"Kandinsky3UNet",
+ "Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
@@ -229,13 +237,16 @@ else:
"MultiAdapter",
"MultiControlNetModel",
"OmniGenTransformer2DModel",
+ "ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
+ "PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
"SanaControlNetModel",
"SanaTransformer2DModel",
+ "SanaVideoTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
@@ -384,10 +395,14 @@ else:
_import_structure["modular_pipelines"].extend(
[
"FluxAutoBlocks",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
"FluxModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
"QwenImageEditModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
@@ -417,6 +432,7 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
+ "BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaPipeline",
@@ -454,6 +470,8 @@ else:
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
"HunyuanDiTPipeline",
+ "HunyuanImagePipeline",
+ "HunyuanImageRefinerPipeline",
"HunyuanSkyreelsImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
"HunyuanVideoImageToVideoPipeline",
@@ -468,6 +486,7 @@ else:
"ImageTextPipelineOutput",
"Kandinsky3Img2ImgPipeline",
"Kandinsky3Pipeline",
+ "Kandinsky5T2VPipeline",
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
"KandinskyImg2ImgPipeline",
@@ -495,6 +514,7 @@ else:
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
+ "LucyEditPipeline",
"Lumina2Pipeline",
"Lumina2Text2ImgPipeline",
"LuminaPipeline",
@@ -510,10 +530,12 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
+ "PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImagePipeline",
@@ -523,6 +545,7 @@ else:
"SanaPipeline",
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
+ "SanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
@@ -837,7 +860,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .guiders import (
AdaptiveProjectedGuidance,
+ AdaptiveProjectedMixGuidance,
AutoGuidance,
+ BaseGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
FrequencyDecoupledGuidance,
@@ -868,6 +893,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
+ AutoencoderKLHunyuanImage,
+ AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -878,6 +905,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
AutoModel,
+ BriaFiboTransformer2DModel,
BriaTransformer2DModel,
CacheMixin,
ChromaTransformer2DModel,
@@ -886,6 +914,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4Transformer2DModel,
ConsisIDTransformer3DModel,
ConsistencyDecoderVAE,
+ ContextParallelConfig,
ControlNetModel,
ControlNetUnionModel,
ControlNetXSAdapter,
@@ -899,10 +928,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
HunyuanDiT2DMultiControlNetModel,
+ HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
I2VGenXLUNet,
Kandinsky3UNet,
+ Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
@@ -913,13 +944,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiAdapter,
MultiControlNetModel,
OmniGenTransformer2DModel,
+ ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
+ PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
@@ -1042,10 +1076,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
FluxAutoBlocks,
+ FluxKontextAutoBlocks,
+ FluxKontextModularPipeline,
FluxModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
@@ -1071,6 +1109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
+ BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaPipeline,
@@ -1108,6 +1147,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
HunyuanDiTPipeline,
+ HunyuanImagePipeline,
+ HunyuanImageRefinerPipeline,
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
HunyuanVideoImageToVideoPipeline,
@@ -1122,6 +1163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ImageTextPipelineOutput,
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
+ Kandinsky5T2VPipeline,
KandinskyCombinedPipeline,
KandinskyImg2ImgCombinedPipeline,
KandinskyImg2ImgPipeline,
@@ -1149,6 +1191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
+ LucyEditPipeline,
Lumina2Pipeline,
Lumina2Text2ImgPipeline,
LuminaPipeline,
@@ -1164,10 +1207,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
+ PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
@@ -1177,6 +1222,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SanaPipeline,
SanaSprintImg2ImgPipeline,
SanaSprintPipeline,
+ SanaVideoPipeline,
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 540aab0307..1c4ee33acb 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -30,11 +30,11 @@ import numpy as np
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from typing_extensions import Self
from . import __version__
@@ -419,7 +419,7 @@ class ConfigMixin:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 79dc4c50a0..6e5ac630ab 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,8 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.34.0",
+ "httpx": "httpx<1.0.0",
+ "huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -51,4 +52,5 @@ deps = {
"black": "black",
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
+ "timm": "timm",
}
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
index 23cb7a0a71..4e53c373c4 100644
--- a/src/diffusers/guiders/__init__.py
+++ b/src/diffusers/guiders/__init__.py
@@ -14,28 +14,18 @@
from typing import Union
-from ..utils import is_torch_available
+from ..utils import is_torch_available, logging
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
+ from .adaptive_projected_guidance_mix import AdaptiveProjectedMixGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
+ from .guider_utils import BaseGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
-
- GuiderType = Union[
- AdaptiveProjectedGuidance,
- AutoGuidance,
- ClassifierFreeGuidance,
- ClassifierFreeZeroStarGuidance,
- FrequencyDecoupledGuidance,
- PerturbedAttentionGuidance,
- SkipLayerGuidance,
- SmoothedEnergyGuidance,
- TangentialClassifierFreeGuidance,
- ]
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
index 92b1fd5a1c..492d10d2f1 100644
--- a/src/diffusers/guiders/adaptive_projected_guidance.py
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -65,8 +65,9 @@ class AdaptiveProjectedGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
@@ -76,19 +77,14 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
@@ -152,6 +148,44 @@ class MomentumBuffer:
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
+ def __repr__(self) -> str:
+ """
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
+ """
+ if isinstance(self.running_average, torch.Tensor):
+ shape = tuple(self.running_average.shape)
+
+ # Calculate statistics
+ with torch.no_grad():
+ stats = {
+ "mean": self.running_average.mean().item(),
+ "std": self.running_average.std().item(),
+ "min": self.running_average.min().item(),
+ "max": self.running_average.max().item(),
+ }
+
+ # Get a slice (max 3 elements per dimension)
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
+ sliced_data = self.running_average[slice_indices]
+
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
+ if len(slice_str) > 200: # Truncate if too long
+ slice_str = slice_str[:200] + "..."
+
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
+
+ return (
+ f"MomentumBuffer(\n"
+ f" momentum={self.momentum},\n"
+ f" shape={shape},\n"
+ f" stats=[{stats_str}],\n"
+ f" slice={slice_str}\n"
+ f")"
+ )
+ else:
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
+
def normalized_guidance(
pred_cond: torch.Tensor,
diff --git a/src/diffusers/guiders/adaptive_projected_guidance_mix.py b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
new file mode 100644
index 0000000000..732741fc92
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance_mix.py
@@ -0,0 +1,284 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AdaptiveProjectedMixGuidance(BaseGuidance):
+ """
+ Adaptive Projected Guidance (APG) https://huggingface.co/papers/2410.02416 combined with Classifier-Free Guidance
+ (CFG). This guider is used in HunyuanImage2.1 https://github.com/Tencent-Hunyuan/HunyuanImage-2.1
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+ The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+ adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+ The rescale factor applied to the noise predictions for adaptive projected guidance. This is used to
+ improve image quality and fix
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions for classifier-free guidance. This is used to improve
+ image quality and fix overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample
+ Steps are Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which the classifier-free guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which the classifier-free guidance stops.
+ adaptive_projected_guidance_start_step (`int`, defaults to `5`):
+ The step at which the adaptive projected guidance starts (before this step, classifier-free guidance is
+ used, and momentum buffer is updated).
+ enabled (`bool`, defaults to `True`):
+ Whether this guidance is enabled.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 3.5,
+ guidance_rescale: float = 0.0,
+ adaptive_projected_guidance_scale: float = 10.0,
+ adaptive_projected_guidance_momentum: float = -0.5,
+ adaptive_projected_guidance_rescale: float = 10.0,
+ eta: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ adaptive_projected_guidance_start_step: int = 5,
+ enabled: bool = True,
+ ):
+ super().__init__(start, stop, enabled)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.adaptive_projected_guidance_scale = adaptive_projected_guidance_scale
+ self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+ self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+ self.eta = eta
+ self.adaptive_projected_guidance_start_step = adaptive_projected_guidance_start_step
+ self.use_original_formulation = use_original_formulation
+ self.momentum_buffer = None
+
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
+ pred = None
+
+ # no guidance
+ if not self._is_cfg_enabled():
+ pred = pred_cond
+
+ # CFG + update momentum buffer
+ elif not self._is_apg_enabled():
+ if self.momentum_buffer is not None:
+ update_momentum_buffer(pred_cond, pred_uncond, self.momentum_buffer)
+ # CFG + update momentum buffer
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ # APG
+ elif self._is_apg_enabled():
+ pred = normalized_guidance(
+ pred_cond,
+ pred_uncond,
+ self.adaptive_projected_guidance_scale,
+ self.momentum_buffer,
+ self.eta,
+ self.adaptive_projected_guidance_rescale,
+ self.use_original_formulation,
+ )
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_apg_enabled() or self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ # Copied from diffusers.guiders.classifier_free_guidance.ClassifierFreeGuidance._is_cfg_enabled
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_apg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ if not self._is_cfg_enabled():
+ return False
+
+ is_within_range = False
+ if self._step is not None:
+ is_within_range = self._step > self.adaptive_projected_guidance_start_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.adaptive_projected_guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.adaptive_projected_guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def get_state(self):
+ state = super().get_state()
+ state["momentum_buffer"] = self.momentum_buffer
+ state["is_apg_enabled"] = self._is_apg_enabled()
+ state["is_cfg_enabled"] = self._is_cfg_enabled()
+ return state
+
+
+# Copied from diffusers.guiders.adaptive_projected_guidance.MomentumBuffer
+class MomentumBuffer:
+ def __init__(self, momentum: float):
+ self.momentum = momentum
+ self.running_average = 0
+
+ def update(self, update_value: torch.Tensor):
+ new_average = self.momentum * self.running_average
+ self.running_average = update_value + new_average
+
+ def __repr__(self) -> str:
+ """
+ Returns a string representation showing momentum, shape, statistics, and a slice of the running_average.
+ """
+ if isinstance(self.running_average, torch.Tensor):
+ shape = tuple(self.running_average.shape)
+
+ # Calculate statistics
+ with torch.no_grad():
+ stats = {
+ "mean": self.running_average.mean().item(),
+ "std": self.running_average.std().item(),
+ "min": self.running_average.min().item(),
+ "max": self.running_average.max().item(),
+ }
+
+ # Get a slice (max 3 elements per dimension)
+ slice_indices = tuple(slice(None, min(3, dim)) for dim in shape)
+ sliced_data = self.running_average[slice_indices]
+
+ # Format the slice for display (convert to float32 for numpy compatibility with bfloat16)
+ slice_str = str(sliced_data.detach().float().cpu().numpy())
+ if len(slice_str) > 200: # Truncate if too long
+ slice_str = slice_str[:200] + "..."
+
+ stats_str = ", ".join([f"{k}={v:.4f}" for k, v in stats.items()])
+
+ return (
+ f"MomentumBuffer(\n"
+ f" momentum={self.momentum},\n"
+ f" shape={shape},\n"
+ f" stats=[{stats_str}],\n"
+ f" slice={slice_str}\n"
+ f")"
+ )
+ else:
+ return f"MomentumBuffer(momentum={self.momentum}, running_average={self.running_average})"
+
+
+def update_momentum_buffer(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+):
+ diff = pred_cond - pred_uncond
+ if momentum_buffer is not None:
+ momentum_buffer.update(diff)
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ guidance_scale: float,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+ eta: float = 1.0,
+ norm_threshold: float = 0.0,
+ use_original_formulation: bool = False,
+):
+ if momentum_buffer is not None:
+ update_momentum_buffer(pred_cond, pred_uncond, momentum_buffer)
+ diff = momentum_buffer.running_average
+ else:
+ diff = pred_cond - pred_uncond
+
+ dim = [-i for i in range(1, len(diff.shape))]
+
+ if norm_threshold > 0:
+ ones = torch.ones_like(diff)
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+ diff = diff * scale_factor
+
+ v0, v1 = diff.double(), pred_cond.double()
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+ normalized_update = diff_orthogonal + eta * diff_parallel
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ pred = pred + guidance_scale * normalized_update
+
+ return pred
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
index 5271a530ea..4374f45aff 100644
--- a/src/diffusers/guiders/auto_guidance.py
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -72,8 +72,9 @@ class AutoGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
@@ -132,16 +133,11 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
index 050590336f..d475b30226 100644
--- a/src/diffusers/guiders/classifier_free_guidance.py
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -27,43 +27,50 @@ if TYPE_CHECKING:
class ClassifierFreeGuidance(BaseGuidance):
"""
- Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
+ Implements Classifier-Free Guidance (CFG) for diffusion models.
- CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
- jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
- inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
- proposes scaling and shifting the conditional distribution based on the difference between conditional and
- unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
+ Reference: https://huggingface.co/papers/2207.12598
- Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
- paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
- theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+ CFG improves generation quality and prompt adherence by jointly training models on both conditional and
+ unconditional data, then combining predictions during inference. This allows trading off between quality (high
+ guidance) and diversity (low guidance).
- The intution behind the original formulation can be thought of as moving the conditional distribution estimates
- further away from the unconditional distribution estimates, while the diffusers-native implementation can be
- thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
- the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
+ **Two CFG Formulations:**
- The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
- paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+ 1. **Original formulation** (from paper):
+ ```
+ x_pred = x_cond + guidance_scale * (x_cond - x_uncond)
+ ```
+ Moves conditional predictions further from unconditional ones.
+
+ 2. **Diffusers-native formulation** (default, from Imagen paper):
+ ```
+ x_pred = x_uncond + guidance_scale * (x_cond - x_uncond)
+ ```
+ Moves unconditional predictions toward conditional ones, effectively suppressing negative features (e.g., "bad
+ quality", "watermarks"). Equivalent in theory but more intuitive.
+
+ Use `use_original_formulation=True` to switch to the original formulation.
Args:
guidance_scale (`float`, defaults to `7.5`):
- The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
- prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
- deterioration of image quality.
+ CFG scale applied by this guider during post-processing. Higher values = stronger prompt conditioning but
+ may reduce quality. Typical range: 1.0-20.0.
guidance_rescale (`float`, defaults to `0.0`):
- The rescale factor applied to the noise predictions. This is used to improve image quality and fix
- overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
- Flawed](https://huggingface.co/papers/2305.08891).
+ Rescaling factor to prevent overexposure from high guidance scales. Based on [Common Diffusion Noise
+ Schedules and Sample Steps are Flawed](https://huggingface.co/papers/2305.08891). Range: 0.0 (no rescaling)
+ to 1.0 (full rescaling).
use_original_formulation (`bool`, defaults to `False`):
- Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
- we use the diffusers-native implementation that has been in the codebase for a long time. See
- [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ If `True`, uses the original CFG formulation from the paper. If `False` (default), uses the
+ diffusers-native formulation from the Imagen paper.
start (`float`, defaults to `0.0`):
- The fraction of the total number of denoising steps after which guidance starts.
+ Fraction of denoising steps (0.0-1.0) after which CFG starts. Use > 0.0 to disable CFG in early denoising
+ steps.
stop (`float`, defaults to `1.0`):
- The fraction of the total number of denoising steps after which guidance stops.
+ Fraction of denoising steps (0.0-1.0) after which CFG stops. Use < 1.0 to disable CFG in late denoising
+ steps.
+ enabled (`bool`, defaults to `True`):
+ Whether CFG is enabled. Set to `False` to disable CFG entirely (uses only conditional predictions).
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@@ -76,23 +83,19 @@ class ClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
index b64e356331..1ea6bbb1c8 100644
--- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py
+++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -68,31 +68,31 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
- if self._step < self.zero_init_steps:
+ # YiYi Notes: add default behavior for self._enabled == False
+ if not self._enabled:
+ pred = pred_cond
+
+ elif self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
diff --git a/src/diffusers/guiders/frequency_decoupled_guidance.py b/src/diffusers/guiders/frequency_decoupled_guidance.py
index 93822a180e..cd542a43a4 100644
--- a/src/diffusers/guiders/frequency_decoupled_guidance.py
+++ b/src/diffusers/guiders/frequency_decoupled_guidance.py
@@ -149,6 +149,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
stop: Union[float, List[float], Tuple[float]] = 1.0,
guidance_rescale_space: str = "data",
upcast_to_double: bool = True,
+ enabled: bool = True,
):
if not _CAN_USE_KORNIA:
raise ImportError(
@@ -160,7 +161,7 @@ class FrequencyDecoupledGuidance(BaseGuidance):
# Set start to earliest start for any freq component and stop to latest stop for any freq component
min_start = start if isinstance(start, float) else min(start)
max_stop = stop if isinstance(stop, float) else max(stop)
- super().__init__(min_start, max_stop)
+ super().__init__(min_start, max_stop, enabled)
self.guidance_scales = guidance_scales
self.levels = len(guidance_scales)
@@ -217,16 +218,11 @@ class FrequencyDecoupledGuidance(BaseGuidance):
f"({len(self.guidance_scales)})"
)
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
index a6f2e76dc3..71e4becfcd 100644
--- a/src/diffusers/guiders/guider_utils.py
+++ b/src/diffusers/guiders/guider_utils.py
@@ -40,7 +40,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
_input_predictions = None
_identifier_key = "__guidance_identifier__"
- def __init__(self, start: float = 0.0, stop: float = 1.0):
+ def __init__(self, start: float = 0.0, stop: float = 1.0, enabled: bool = True):
+ logger.warning(
+ "Guiders are currently an experimental feature under active development. The API is subject to breaking changes in future releases."
+ )
+
self._start = start
self._stop = stop
self._step: int = None
@@ -48,7 +52,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
- self._enabled = True
+ self._enabled = enabled
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
@@ -60,6 +64,31 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
+ def new(self, **kwargs):
+ """
+ Creates a copy of this guider instance, optionally with modified configuration parameters.
+
+ Args:
+ **kwargs: Configuration parameters to override in the new instance. If no kwargs are provided,
+ returns an exact copy with the same configuration.
+
+ Returns:
+ A new guider instance with the same (or updated) configuration.
+
+ Example:
+ ```python
+ # Create a CFG guider
+ guider = ClassifierFreeGuidance(guidance_scale=3.5)
+
+ # Create an exact copy
+ same_guider = guider.new()
+
+ # Create a copy with different start step, keeping other config the same
+ new_guider = guider.new(guidance_scale=5)
+ ```
+ """
+ return self.__class__.from_config(self.config, **kwargs)
+
def disable(self):
self._enabled = False
@@ -72,42 +101,52 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep = timestep
self._count_prepared = 0
- def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
+ def get_state(self) -> Dict[str, Any]:
"""
- Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
- attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
- the values of the provided keyword arguments to this method.
-
- Args:
- **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
- A dictionary where the keys are the names of the fields that will be used to store the data once it is
- prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
- to look up the required data provided for preparation.
-
- If a string is provided, it will be used as the conditional data (or unconditional if used with a
- guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
- conditional data identifier and the second element must be the unconditional data identifier or None.
-
- Example:
- ```
- data = {"prompt_embeds":
, "negative_prompt_embeds": , "latents": }
-
- BaseGuidance.set_input_fields(
- latents="latents",
- prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
- )
- ```
+ Returns the current state of the guidance technique as a dictionary. The state variables will be included in
+ the __repr__ method. Returns:
+ `Dict[str, Any]`: A dictionary containing the current state variables including:
+ - step: Current inference step
+ - num_inference_steps: Total number of inference steps
+ - timestep: Current timestep tensor
+ - count_prepared: Number of times prepare_models has been called
+ - enabled: Whether the guidance is enabled
+ - num_conditions: Number of conditions
"""
- for key, value in kwargs.items():
- is_string = isinstance(value, str)
- is_tuple_of_str_with_len_2 = (
- isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
- )
- if not (is_string or is_tuple_of_str_with_len_2):
- raise ValueError(
- f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
- )
- self._input_fields = kwargs
+ state = {
+ "step": self._step,
+ "num_inference_steps": self._num_inference_steps,
+ "timestep": self._timestep,
+ "count_prepared": self._count_prepared,
+ "enabled": self._enabled,
+ "num_conditions": self.num_conditions,
+ }
+ return state
+
+ def __repr__(self) -> str:
+ """
+ Returns a string representation of the guidance object including both config and current state.
+ """
+ # Get ConfigMixin's __repr__
+ str_repr = super().__repr__()
+
+ # Get current state
+ state = self.get_state()
+
+ # Format each state variable on its own line with indentation
+ state_lines = []
+ for k, v in state.items():
+ # Convert value to string and handle multi-line values
+ v_str = str(v)
+ if "\n" in v_str:
+ # For multi-line values (like MomentumBuffer), indent subsequent lines
+ v_lines = v_str.split("\n")
+ v_str = v_lines[0] + "\n" + "\n".join([" " + line for line in v_lines[1:]])
+ state_lines.append(f" {k}: {v_str}")
+
+ state_str = "\n".join(state_lines)
+
+ return f"{str_repr}\nState:\n{state_str}"
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
@@ -155,8 +194,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
@classmethod
def _prepare_batch(
cls,
- input_fields: Dict[str, Union[str, Tuple[str, str]]],
- data: "BlockState",
+ data: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
tuple_index: int,
identifier: str,
) -> "BlockState":
@@ -182,21 +220,16 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
"""
from ..modular_pipelines.modular_pipeline import BlockState
- if input_fields is None:
- raise ValueError(
- "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
- )
data_batch = {}
- for key, value in input_fields.items():
+ for key, value in data.items():
try:
- if isinstance(value, str):
- data_batch[key] = getattr(data, value)
+ if isinstance(value, torch.Tensor):
+ data_batch[key] = value
elif isinstance(value, tuple):
- data_batch[key] = getattr(data, value[tuple_index])
+ data_batch[key] = value[tuple_index]
else:
- # We've already checked that value is a string or a tuple of strings with length 2
- pass
- except AttributeError:
+ raise ValueError(f"Invalid value type: {type(value)}")
+ except ValueError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
@@ -247,15 +280,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
index e294e8d0db..29341736e8 100644
--- a/src/diffusers/guiders/perturbed_attention_guidance.py
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -98,8 +98,9 @@ class PerturbedAttentionGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = perturbed_guidance_scale
@@ -168,12 +169,7 @@ class PerturbedAttentionGuidance(BaseGuidance):
registry.remove_hook(hook_name, recurse=True)
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -186,8 +182,8 @@ class PerturbedAttentionGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
index 3530df8b0a..fa5b93b680 100644
--- a/src/diffusers/guiders/skip_layer_guidance.py
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -100,8 +100,9 @@ class SkipLayerGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
@@ -164,12 +165,7 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -182,8 +178,8 @@ class SkipLayerGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py
index 767d20b62f..7446b33f12 100644
--- a/src/diffusers/guiders/smoothed_energy_guidance.py
+++ b/src/diffusers/guiders/smoothed_energy_guidance.py
@@ -92,8 +92,9 @@ class SmoothedEnergyGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
@@ -153,12 +154,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -171,8 +167,8 @@ class SmoothedEnergyGuidance(BaseGuidance):
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
index df1e69fe71..cfa3c4a616 100644
--- a/src/diffusers/guiders/tangential_classifier_free_guidance.py
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
import torch
@@ -58,23 +58,19 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
+ enabled: bool = True,
):
- super().__init__(start, stop)
+ super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
- def prepare_inputs(
- self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
- ) -> List["BlockState"]:
- if input_fields is None:
- input_fields = self._input_fields
-
+ def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
- for i in range(self.num_conditions):
- data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
+ data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 525a0747da..524a92ea99 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -16,6 +16,7 @@ from ..utils import is_torch_available
if is_torch_available():
+ from .context_parallel import apply_context_parallel
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
index f6e5bdd52d..790199f3c9 100644
--- a/src/diffusers/hooks/_helpers.py
+++ b/src/diffusers/hooks/_helpers.py
@@ -108,6 +108,7 @@ def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
from ..models.transformers.transformer_flux import FluxAttnProcessor
+ from ..models.transformers.transformer_hunyuanimage import HunyuanImageAttnProcessor
from ..models.transformers.transformer_qwenimage import QwenDoubleStreamAttnProcessor2_0
from ..models.transformers.transformer_wan import WanAttnProcessor2_0
@@ -149,6 +150,14 @@ def _register_attention_processors_metadata():
),
)
+ # HunyuanImageAttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=HunyuanImageAttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor,
+ ),
+ )
+
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
@@ -162,6 +171,10 @@ def _register_transformer_blocks_metadata():
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
+ from ..models.transformers.transformer_hunyuanimage import (
+ HunyuanImageSingleTransformerBlock,
+ HunyuanImageTransformerBlock,
+ )
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
@@ -283,6 +296,22 @@ def _register_transformer_blocks_metadata():
),
)
+ # HunyuanImage2.1
+ TransformerBlockRegistry.register(
+ model_class=HunyuanImageTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanImageSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
@@ -308,4 +337,5 @@ _skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hid
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_QwenDoubleStreamAttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_HunyuanImageAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
new file mode 100644
index 0000000000..915fe453b9
--- /dev/null
+++ b/src/diffusers/hooks/context_parallel.py
@@ -0,0 +1,300 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Dict, List, Type, Union
+
+import torch
+
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
+from ..models._modeling_parallel import (
+ ContextParallelConfig,
+ ContextParallelInput,
+ ContextParallelModelPlan,
+ ContextParallelOutput,
+)
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
+_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
+
+
+# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
+@dataclass
+class ModuleForwardMetadata:
+ cached_parameter_indices: Dict[str, int] = None
+ _cls: Type = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+
+ if identifier in kwargs:
+ return kwargs[identifier], True, None
+
+ if self.cached_parameter_indices is not None:
+ index = self.cached_parameter_indices.get(identifier, None)
+ if index is None:
+ raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
+ return args[index], False, index
+
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+
+ if identifier not in self.cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+
+ index = self.cached_parameter_indices[identifier]
+
+ if index >= len(args):
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+
+ return args[index], False, index
+
+
+def apply_context_parallel(
+ module: torch.nn.Module,
+ parallel_config: ContextParallelConfig,
+ plan: Dict[str, ContextParallelModelPlan],
+) -> None:
+ """Apply context parallel on a model."""
+ logger.debug(f"Applying context parallel with CP mesh: {parallel_config._mesh} and plan: {plan}")
+
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")
+
+ for m in submodule:
+ if isinstance(cp_model_plan, dict):
+ hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ if isinstance(cp_model_plan, ContextParallelOutput):
+ cp_model_plan = [cp_model_plan]
+ if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
+ raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
+ hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ registry.register_hook(hook, hook_name)
+
+
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+ for module_id, cp_model_plan in plan.items():
+ submodule = _get_submodule_by_name(module, module_id)
+ if not isinstance(submodule, list):
+ submodule = [submodule]
+
+ for m in submodule:
+ registry = HookRegistry.check_if_exists_or_initialize(m)
+ if isinstance(cp_model_plan, dict):
+ hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+ elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+ hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+ else:
+ raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
+ registry.remove_hook(hook_name)
+
+
+class ContextParallelSplitHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+ self.module_forward_metadata = None
+
+ def initialize_hook(self, module):
+ cls = unwrap_module(module).__class__
+ self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
+ return module
+
+ def pre_forward(self, module, *args, **kwargs):
+ args_list = list(args)
+
+ for name, cpm in self.metadata.items():
+ if isinstance(cpm, ContextParallelInput) and cpm.split_output:
+ continue
+
+ # Maybe the parameter was passed as a keyword argument
+ input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
+ name, args_list, kwargs
+ )
+
+ if input_val is None:
+ continue
+
+ # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
+ # the output instead of input for a particular layer by setting split_output=True
+ if isinstance(input_val, torch.Tensor):
+ input_val = self._prepare_cp_input(input_val, cpm)
+ elif isinstance(input_val, (list, tuple)):
+ if len(input_val) != len(cpm):
+ raise ValueError(
+ f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
+ )
+ sharded_input_val = []
+ for i, x in enumerate(input_val):
+ if torch.is_tensor(x) and not cpm[i].split_output:
+ x = self._prepare_cp_input(x, cpm[i])
+ sharded_input_val.append(x)
+ input_val = sharded_input_val
+ else:
+ raise ValueError(f"Unsupported input type: {type(input_val)}")
+
+ if is_kwarg:
+ kwargs[name] = input_val
+ elif index is not None and index < len(args_list):
+ args_list[index] = input_val
+ else:
+ raise ValueError(
+ f"An unexpected error occurred while processing the input '{name}'. Please open an "
+ f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
+ f"example along with the full stack trace."
+ )
+
+ return tuple(args_list), kwargs
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+ is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)
+
+ if not is_tensor and not is_tensor_list:
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = [output] if is_tensor else list(output)
+ for index, cpm in self.metadata.items():
+ if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
+ continue
+ if index >= len(output):
+ raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
+ current_output = output[index]
+ current_output = self._prepare_cp_input(current_output, cpm)
+ output[index] = current_output
+
+ return output[0] if is_tensor else tuple(output)
+
+ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+ if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
+ raise ValueError(
+ f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
+ )
+ return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+
+class ContextParallelGatherHook(ModelHook):
+ def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ContextParallelConfig) -> None:
+ super().__init__()
+ self.metadata = metadata
+ self.parallel_config = parallel_config
+
+ def post_forward(self, module, output):
+ is_tensor = isinstance(output, torch.Tensor)
+
+ if is_tensor:
+ output = [output]
+ elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
+ raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")
+
+ output = list(output)
+
+ if len(output) != len(self.metadata):
+ raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")
+
+ for i, cpm in enumerate(self.metadata):
+ if cpm is None:
+ continue
+ output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+ return output[0] if is_tensor else tuple(output)
+
+
+class AllGatherFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, tensor, dim, group):
+ ctx.dim = dim
+ ctx.group = group
+ ctx.world_size = torch.distributed.get_world_size(group)
+ ctx.rank = torch.distributed.get_rank(group)
+ return funcol.all_gather_tensor(tensor, dim, group=group)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_chunks = torch.chunk(grad_output, ctx.world_size, dim=ctx.dim)
+ return grad_chunks[ctx.rank], None, None
+
+
+class EquipartitionSharder:
+ @classmethod
+ def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ # NOTE: the following assertion does not have to be true in general. We simply enforce it for now
+ # because the alternate case has not yet been tested/required for any model.
+ assert tensor.size()[dim] % mesh.size() == 0, (
+ "Tensor size along dimension to be sharded must be divisible by mesh size"
+ )
+
+ # The following is not fullgraph compatible with Dynamo (fails in DeviceMesh.get_rank)
+ # return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]
+
+ return tensor.chunk(mesh.size(), dim=dim)[torch.distributed.get_rank(mesh.get_group())]
+
+ @classmethod
+ def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
+ tensor = tensor.contiguous()
+ tensor = AllGatherFunction.apply(tensor, dim, mesh.get_group())
+ return tensor
+
+
+def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name.count("*") > 1:
+ raise ValueError("Wildcard '*' can only be used once in the name")
+ return _find_submodule_by_name(model, name)
+
+
+def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
+ if name == "":
+ return model
+ first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
+ if first_atom == "*":
+ if not isinstance(model, torch.nn.ModuleList):
+ raise ValueError("Wildcard '*' can only be used with ModuleList")
+ submodules = []
+ for submodule in model:
+ subsubmodules = _find_submodule_by_name(submodule, remaining_name)
+ if not isinstance(subsubmodules, list):
+ subsubmodules = [subsubmodules]
+ submodules.extend(subsubmodules)
+ return submodules
+ else:
+ if hasattr(model, first_atom):
+ submodule = getattr(model, first_atom)
+ return _find_submodule_by_name(submodule, remaining_name)
+ else:
+ raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 0e3082eada..067d876ffc 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -1045,16 +1045,39 @@ class VaeImageProcessorLDM3D(VaeImageProcessor):
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
r"""
Convert an RGB-like depth image to a depth map.
-
- Args:
- image (`Union[np.ndarray, torch.Tensor]`):
- The RGB-like depth image to convert.
-
- Returns:
- `Union[np.ndarray, torch.Tensor]`:
- The corresponding depth map.
"""
- return image[:, :, 1] * 2**8 + image[:, :, 2]
+ # 1. Cast the tensor to a larger integer type (e.g., int32)
+ # to safely perform the multiplication by 256.
+ # 2. Perform the 16-bit combination: High-byte * 256 + Low-byte.
+ # 3. Cast the final result to the desired depth map type (uint16) if needed
+ # before returning, though leaving it as int32/int64 is often safer
+ # for return value from a library function.
+
+ if isinstance(image, torch.Tensor):
+ # Cast to a safe dtype (e.g., int32 or int64) for the calculation
+ original_dtype = image.dtype
+ image_safe = image.to(torch.int32)
+
+ # Calculate the depth map
+ depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+ # You may want to cast the final result to uint16, but casting to a
+ # larger int type (like int32) is sufficient to fix the overflow.
+ # depth_map = depth_map.to(torch.uint16) # Uncomment if uint16 is strictly required
+ return depth_map.to(original_dtype)
+
+ elif isinstance(image, np.ndarray):
+ # NumPy equivalent: Cast to a safe dtype (e.g., np.int32)
+ original_dtype = image.dtype
+ image_safe = image.astype(np.int32)
+
+ # Calculate the depth map
+ depth_map = image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
+
+ # depth_map = depth_map.astype(np.uint16) # Uncomment if uint16 is strictly required
+ return depth_map.astype(original_dtype)
+ else:
+ raise TypeError("Input image must be a torch.Tensor or np.ndarray")
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
r"""
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 7425486538..48507aae03 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -77,6 +77,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
+ "KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
"QwenImageLoraLoaderMixin",
@@ -115,6 +116,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
+ KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index d18c82df4f..3d75a7d875 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -544,11 +544,7 @@ class LoraBaseMixin:
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -628,11 +624,7 @@ class LoraBaseMixin:
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -1064,6 +1056,41 @@ class LoraBaseMixin:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
+ @classmethod
+ def _save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ lora_layers: Dict[str, Dict[str, Union[torch.nn.Module, torch.Tensor]]],
+ lora_metadata: Dict[str, Optional[dict]],
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ """
+ Helper method to pack and save LoRA weights and metadata. This method centralizes the saving logic for all
+ pipeline types.
+ """
+ state_dict = {}
+ final_lora_adapter_metadata = {}
+
+ for prefix, layers in lora_layers.items():
+ state_dict.update(cls.pack_weights(layers, prefix))
+
+ for prefix, metadata in lora_metadata.items():
+ if metadata:
+ final_lora_adapter_metadata.update(_pack_dict_with_prefix(metadata, prefix))
+
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ lora_adapter_metadata=final_lora_adapter_metadata if final_lora_adapter_metadata else None,
+ )
+
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 6f584a5f0e..2807416f97 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -558,70 +558,62 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
ait_sd[target_key] = value
if any("guidance_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.guidance_embedder.linear_1.{lora_key}.weight",
- "lora_unet_guidance_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.guidance_embedder.linear_2.{lora_key}.weight",
- "lora_unet_guidance_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_in_layer",
+ "time_text_embed.guidance_embedder.linear_1",
+ )
+
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_guidance_in_out_layer",
+ "time_text_embed.guidance_embedder.linear_2",
)
if any("img_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- ("x_embedder.{lora_key}.weight", "lora_unet_img_in.{orig_lora_key}.weight", None),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_img_in",
+ "x_embedder",
)
if any("txt_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- ("context_embedder.{lora_key}.weight", "lora_unet_txt_in.{orig_lora_key}.weight", None),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_txt_in",
+ "context_embedder",
)
if any("time_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.timestep_embedder.linear_1.{lora_key}.weight",
- "lora_unet_time_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.timestep_embedder.linear_2.{lora_key}.weight",
- "lora_unet_time_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_time_in_in_layer",
+ "time_text_embed.timestep_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_time_in_out_layer",
+ "time_text_embed.timestep_embedder.linear_2",
)
if any("vector_in" in k for k in sds_sd):
- assign_remaining_weights(
- [
- (
- "time_text_embed.text_embedder.linear_1.{lora_key}.weight",
- "lora_unet_vector_in_in_layer.{orig_lora_key}.weight",
- None,
- ),
- (
- "time_text_embed.text_embedder.linear_2.{lora_key}.weight",
- "lora_unet_vector_in_out_layer.{orig_lora_key}.weight",
- None,
- ),
- ],
+ _convert_to_ai_toolkit(
sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_in_layer",
+ "time_text_embed.text_embedder.linear_1",
+ )
+ _convert_to_ai_toolkit(
+ sds_sd,
+ ait_sd,
+ "lora_unet_vector_in_out_layer",
+ "time_text_embed.text_embedder.linear_2",
)
if any("final_layer" in k for k in sds_sd):
@@ -1985,14 +1977,34 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
"time_projection.1.diff_b"
)
- if any("head.head" in k for k in state_dict):
- converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
- f"head.head.{lora_down_key}.weight"
- )
- converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(f"head.head.{lora_up_key}.weight")
+ if any("head.head" in k for k in original_state_dict):
+ if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
+ f"head.head.{lora_down_key}.weight"
+ )
+ if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
+ converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
+ f"head.head.{lora_up_key}.weight"
+ )
if "head.head.diff_b" in original_state_dict:
converted_state_dict["proj_out.lora_B.bias"] = original_state_dict.pop("head.head.diff_b")
+ # Notes: https://huggingface.co/lightx2v/Wan2.2-Distill-Loras
+ # This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
+ # Since for this particular LoRA, we don't have the corresponding up matrix, I will use
+ # an identity.
+ if any("head.head" in k and k.endswith(".diff") for k in state_dict):
+ if f"head.head.{lora_down_key}.weight" in state_dict:
+ logger.info(
+ f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."
+ )
+ converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop("head.head.diff")
+ down_matrix_head = converted_state_dict["proj_out.lora_A.weight"]
+ up_matrix_shape = (down_matrix_head.shape[0], converted_state_dict["proj_out.lora_B.bias"].shape[0])
+ converted_state_dict["proj_out.lora_B.weight"] = torch.eye(
+ *up_matrix_shape, dtype=down_matrix_head.dtype, device=down_matrix_head.device
+ ).T
+
for text_time in ["text_embedding", "time_embedding"]:
if any(text_time in k for k in original_state_dict):
for b_n in [0, 2]:
@@ -2201,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
state_dict = {convert_key(k): v for k, v in state_dict.items()}
+ has_default = any("default." in k for k in state_dict)
+ if has_default:
+ state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
+
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 7e89066f1f..25919a896a 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -246,13 +246,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is
+ experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -510,35 +505,28 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- if unet_lora_adapter_metadata:
- lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `unet_lora_layers` or `text_encoder_lora_layers`.")
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -552,11 +540,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
@@ -593,11 +577,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -628,33 +608,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is
- loaded into `self.unet`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state
- dict is loaded into `self.text_encoder`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -731,13 +685,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
r"""
Return state dict for lora weights and the network alphas.
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
+ > [!WARNING] > We support loading A1111 formatted LoRA checkpoints in a limited capacity. > > This function is
+ experimental and might change in the future.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -974,74 +923,36 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `unet`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- unet_lora_adapter_metadata:
- LoRA adapter metadata associated with the unet to be serialized with the state dict.
- text_encoder_lora_adapter_metadata:
- LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
- text_encoder_2_lora_adapter_metadata:
- LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if unet_lora_layers:
- state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
+ lora_layers[cls.unet_name] = unet_lora_layers
+ lora_metadata[cls.unet_name] = unet_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- if unet_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
-
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
)
- if text_encoder_2_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
- )
-
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -1053,35 +964,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1093,21 +976,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
def unfuse_lora(self, components: List[str] = ["unet", "text_encoder", "text_encoder_2"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1133,51 +1002,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1231,30 +1056,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
- `self.text_encoder`.
-
- All kwargs are forwarded to `self.lora_state_dict`.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
- loaded.
-
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1323,26 +1125,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SD3Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1437,76 +1220,36 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
text_encoder_2_lora_adapter_metadata=None,
):
r"""
- Save the LoRA parameters corresponding to the UNet and text encoder.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `text_encoder_2`. Must explicitly pass the text
- encoder LoRA state dict because it comes from 🤗 Transformers.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
- text_encoder_lora_adapter_metadata:
- LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
- text_encoder_2_lora_adapter_metadata:
- LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
- raise ValueError(
- "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
- )
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
+ lora_layers["text_encoder"] = text_encoder_lora_layers
+ lora_metadata["text_encoder"] = text_encoder_lora_adapter_metadata
if text_encoder_2_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+ lora_layers["text_encoder_2"] = text_encoder_2_lora_layers
+ lora_metadata["text_encoder_2"] = text_encoder_2_lora_adapter_metadata
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+ if not lora_layers:
+ raise ValueError(
+ "You must pass at least one of `transformer_lora_layers`, `text_encoder_lora_layers`, or `text_encoder_2_lora_layers`."
)
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- if text_encoder_2_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
- )
-
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1519,35 +1262,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1560,21 +1275,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
- unfuse_text_encoder (`bool`, defaults to `True`):
- Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
- LoRA parameters then it won't have any effect.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1596,51 +1297,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -1695,25 +1352,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -1759,26 +1398,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`AuraFlowTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -1810,48 +1430,26 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -1864,35 +1462,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -1905,18 +1475,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -1943,50 +1502,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -2240,30 +1756,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`FluxTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -2435,37 +1928,28 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
- state_dict = {}
- lora_adapter_metadata = {}
-
- if not (transformer_lora_layers or text_encoder_lora_layers):
- raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
+ lora_layers = {}
+ lora_metadata = {}
if transformer_lora_layers:
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if text_encoder_lora_layers:
- state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
+ lora_layers[cls.text_encoder_name] = text_encoder_lora_layers
+ lora_metadata[cls.text_encoder_name] = text_encoder_lora_adapter_metadata
- if transformer_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if text_encoder_lora_adapter_metadata:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -2477,35 +1961,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
@@ -2533,11 +1989,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
Reverses the effect of
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
- This is an experimental API.
-
-
+ > [!WARNING] > This is an experimental API.
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
@@ -2848,30 +2300,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
hotswap: bool = False,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- network_alphas (`Dict[str, float]`):
- The value of the network alpha used for stable learning and preventing underflow. This value has the
- same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
- link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
- transformer (`UVit2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
@@ -3021,51 +2450,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3119,25 +2504,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3183,26 +2550,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogVideoXTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3222,7 +2570,6 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
)
@classmethod
- # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
@@ -3234,48 +2581,26 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
@@ -3287,35 +2612,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3327,18 +2624,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -3360,51 +2646,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3459,25 +2701,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3523,26 +2747,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`MochiTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3574,48 +2779,26 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3628,35 +2811,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -3669,18 +2824,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -3701,50 +2845,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -3803,25 +2904,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -3867,26 +2950,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`LTXVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -3918,48 +2982,26 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -3972,35 +3014,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4013,18 +3027,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4046,51 +3049,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4145,25 +3104,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4209,26 +3150,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SanaTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4260,48 +3182,26 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4314,35 +3214,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4355,18 +3227,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4387,50 +3248,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading original format HunyuanVideo LoRA checkpoints.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4489,25 +3307,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4553,26 +3353,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HunyuanVideoTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4604,48 +3385,26 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -4658,35 +3417,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -4699,18 +3430,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -4731,50 +3451,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -4834,25 +3511,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -4898,26 +3557,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`Lumina2Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -4949,48 +3589,26 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -5003,35 +3621,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5044,18 +3634,292 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
+ """
+ super().unfuse_lora(components=components, **kwargs)
-
- This is an experimental API.
+class KandinskyLoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`Kandinsky5Transformer3DModel`],
+ """
-
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+ - A string, the *model id* of a pretrained model hosted on the Hub.
+ - A path to a *directory* containing the model weights.
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository.
+ weight_name (`str`, *optional*, defaults to None):
+ Name of the serialized state dict file.
+ use_safetensors (`bool`, *optional*):
+ Whether to use safetensors for loading.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata.
+ """
+ # Load the main state dict first which has the LoRA layers
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model.
+ hotswap (`bool`, *optional*):
+ Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ kwargs (`dict`, *optional*):
+ See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ # Load LoRA into transformer
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ Load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters.
+ transformer (`Kandinsky5Transformer3DModel`):
+ The transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata.
+ """
+ if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata=None,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer and text encoders.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process.
+ save_function (`Callable`):
+ The function to use to save the state dictionary.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way.
+ transformer_lora_adapter_metadata:
+ LoRA adapter metadata associated with the transformer.
+ """
+ lora_layers = {}
+ lora_metadata = {}
+
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
+
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers`")
+
+ cls._save_lora_weights(
+ save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing.
+
+ Example:
+ ```py
+ from diffusers import Kandinsky5T2VPipeline
+
+ pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
+ pipeline.load_lora_weights("path/to/lora.safetensors")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of [`pipe.fuse_lora()`].
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5076,50 +3940,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -5225,25 +4046,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -5313,26 +4116,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`WanTransformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -5364,48 +4148,26 @@ class WanLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5418,35 +4180,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5459,18 +4193,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5492,50 +4215,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -5643,25 +4323,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -5731,26 +4393,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`SkyReelsV2Transformer3DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -5782,48 +4425,26 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -5836,35 +4457,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -5877,18 +4470,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -5910,51 +4492,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -6009,25 +4547,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6073,26 +4593,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`CogView4Transformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6124,48 +4625,26 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6178,35 +4657,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6219,18 +4670,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -6251,50 +4691,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -6353,25 +4750,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6417,26 +4796,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`HiDreamImageTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6468,48 +4828,26 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
@@ -6522,35 +4860,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6563,18 +4873,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
@@ -6595,51 +4894,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Return state dict for lora weights and the network alphas.
-
-
-
- We support loading A1111 formatted LoRA checkpoints in a limited capacity.
-
- This function is experimental and might change in the future.
-
-
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- Can be either:
-
- - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
- the Hub.
- - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
- with [`ModelMixin.save_pretrained`].
- - A [torch state
- dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-
- cache_dir (`Union[str, os.PathLike]`, *optional*):
- Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
- is not used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- cached versions if they exist.
-
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
- local_files_only (`bool`, *optional*, defaults to `False`):
- Whether to only load local model weights and configuration files or not. If set to `True`, the model
- won't be downloaded from the Hub.
- token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
- `diffusers-cli login` (stored in `~/.huggingface`) is used.
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
- allowed by Git.
- subfolder (`str`, *optional*, defaults to `""`):
- The subfolder location of a model file within a larger model repository on the Hub or locally.
- return_lora_metadata (`bool`, *optional*, defaults to False):
- When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
@@ -6685,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
- if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
+ has_default = any("default." in k for k in state_dict)
+ if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default:
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
out = (state_dict, metadata) if return_lora_metadata else state_dict
@@ -6700,25 +4956,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
"""
- Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
- `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
- [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
- dict is loaded into `self.transformer`.
-
- Parameters:
- pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- kwargs (`dict`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -6764,26 +5002,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
metadata=None,
):
"""
- This will load the LoRA layers specified in `state_dict` into `transformer`.
-
- Parameters:
- state_dict (`dict`):
- A standard state dict containing the lora layer parameters. The keys can either be indexed directly
- into the unet or prefixed with an additional `unet` which can be used to distinguish between text
- encoder lora layers.
- transformer (`QwenImageTransformer2DModel`):
- The Transformer model to load the LoRA layers into.
- adapter_name (`str`, *optional*):
- Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
- `default_{i}` where i is the total number of adapters being loaded.
- low_cpu_mem_usage (`bool`, *optional*):
- Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
- weights.
- hotswap (`bool`, *optional*):
- See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
- metadata (`dict`):
- Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
- from the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
@@ -6815,48 +5034,26 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
- Save the LoRA parameters corresponding to the transformer.
-
- Arguments:
- save_directory (`str` or `os.PathLike`):
- Directory to save LoRA parameters to. Will be created if it doesn't exist.
- transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
- State dict of the LoRA layers corresponding to the `transformer`.
- is_main_process (`bool`, *optional*, defaults to `True`):
- Whether the process calling this is the main process or not. Useful during distributed training and you
- need to call this function on all processes. In this case, set `is_main_process=True` only on the main
- process to avoid race conditions.
- save_function (`Callable`):
- The function to use to save the state dictionary. Useful during distributed training when you need to
- replace `torch.save` with another method. Can be configured with the environment variable
- `DIFFUSERS_SAVE_MODE`.
- safe_serialization (`bool`, *optional*, defaults to `True`):
- Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
- transformer_lora_adapter_metadata:
- LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
- state_dict = {}
- lora_adapter_metadata = {}
+ lora_layers = {}
+ lora_metadata = {}
- if not transformer_lora_layers:
- raise ValueError("You must pass `transformer_lora_layers`.")
+ if transformer_lora_layers:
+ lora_layers[cls.transformer_name] = transformer_lora_layers
+ lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
- state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+ if not lora_layers:
+ raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
- if transformer_lora_adapter_metadata is not None:
- lora_adapter_metadata.update(
- _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
- )
-
- # Save the model
- cls.write_lora_layers(
- state_dict=state_dict,
+ cls._save_lora_weights(
save_directory=save_directory,
+ lora_layers=lora_layers,
+ lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
- lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
@@ -6869,35 +5066,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
**kwargs,
):
r"""
- Fuses the LoRA parameters into the original parameters of the corresponding blocks.
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
- lora_scale (`float`, defaults to 1.0):
- Controls how much to influence the outputs with the LoRA parameters.
- safe_fusing (`bool`, defaults to `False`):
- Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
- adapter_names (`List[str]`, *optional*):
- Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
-
- Example:
-
- ```py
- from diffusers import DiffusionPipeline
- import torch
-
- pipeline = DiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
- ).to("cuda")
- pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
- pipeline.fuse_lora(lora_scale=0.7)
- ```
+ See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
@@ -6910,18 +5079,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
- Reverses the effect of
- [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
-
-
-
- This is an experimental API.
-
-
-
- Args:
- components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
- unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 2381ccfef3..7d65b30659 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -293,7 +293,7 @@ class PeftAdapterMixin:
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
- if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
+ if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 49ac2a1c56..202e77fd19 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}
if is_torch_available():
+ _import_structure["_modeling_parallel"] = ["ContextParallelConfig", "ParallelConfig"]
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
@@ -35,6 +36,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_cosmos"] = ["AutoencoderKLCosmos"]
_import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
+ _import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
@@ -81,6 +84,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
_import_structure["transformers.transformer_bria"] = ["BriaTransformer2DModel"]
+ _import_structure["transformers.transformer_bria_fibo"] = ["BriaFiboTransformer2DModel"]
_import_structure["transformers.transformer_chroma"] = ["ChromaTransformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"]
@@ -90,11 +94,15 @@ if is_torch_available():
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
+ _import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
+ _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+ _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
+ _import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -119,6 +127,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
+ from ._modeling_parallel import ContextParallelConfig, ParallelConfig
from .adapter import MultiAdapter, T2IAdapter
from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
@@ -129,6 +138,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLAllegro,
AutoencoderKLCogVideoX,
AutoencoderKLCosmos,
+ AutoencoderKLHunyuanImage,
+ AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
@@ -165,6 +176,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AllegroTransformer3DModel,
AuraFlowTransformer2DModel,
+ BriaFiboTransformer2DModel,
BriaTransformer2DModel,
ChromaTransformer2DModel,
CogVideoXTransformer3DModel,
@@ -178,8 +190,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
+ HunyuanImageTransformer2DModel,
HunyuanVideoFramepackTransformer3DModel,
HunyuanVideoTransformer3DModel,
+ Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
@@ -188,8 +202,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
+ PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
+ SanaVideoTransformer3DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
diff --git a/src/diffusers/models/_modeling_parallel.py b/src/diffusers/models/_modeling_parallel.py
new file mode 100644
index 0000000000..2a1d2cc6ce
--- /dev/null
+++ b/src/diffusers/models/_modeling_parallel.py
@@ -0,0 +1,241 @@
+# 🚨🚨🚨 Experimental parallelism support for Diffusers 🚨🚨🚨
+# Experimental changes are subject to change and APIs may break without warning.
+
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+
+
+if TYPE_CHECKING:
+ pass
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(aryan): add support for the following:
+# - Unified Attention
+# - More dispatcher attention backends
+# - CFG/Data Parallel
+# - Tensor Parallel
+
+
+@dataclass
+class ContextParallelConfig:
+ """
+ Configuration for context parallelism.
+
+ Args:
+ ring_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
+ total number of devices in the context parallel mesh.
+ ulysses_degree (`int`, *optional*, defaults to `1`):
+ Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
+ total number of devices in the context parallel mesh.
+ convert_to_fp32 (`bool`, *optional*, defaults to `True`):
+ Whether to convert output and LSE to float32 for ring attention numerical stability.
+ rotate_method (`str`, *optional*, defaults to `"allgather"`):
+ Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
+ is supported.
+
+ """
+
+ ring_degree: Optional[int] = None
+ ulysses_degree: Optional[int] = None
+ convert_to_fp32: bool = True
+ # TODO: support alltoall
+ rotate_method: Literal["allgather", "alltoall"] = "allgather"
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
+ _ring_local_rank: int = None
+ _ulysses_local_rank: int = None
+
+ def __post_init__(self):
+ if self.ring_degree is None:
+ self.ring_degree = 1
+ if self.ulysses_degree is None:
+ self.ulysses_degree = 1
+
+ def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._mesh = mesh
+ if self.ring_degree is None:
+ self.ring_degree = 1
+ if self.ulysses_degree is None:
+ self.ulysses_degree = 1
+ if self.rotate_method != "allgather":
+ raise NotImplementedError(
+ f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+ )
+ if self._flattened_mesh is None:
+ self._flattened_mesh = self._mesh._flatten()
+ if self._ring_mesh is None:
+ self._ring_mesh = self._mesh["ring"]
+ if self._ulysses_mesh is None:
+ self._ulysses_mesh = self._mesh["ulysses"]
+ if self._ring_local_rank is None:
+ self._ring_local_rank = self._ring_mesh.get_local_rank()
+ if self._ulysses_local_rank is None:
+ self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+
+
+@dataclass
+class ParallelConfig:
+ """
+ Configuration for applying different parallelisms.
+
+ Args:
+ context_parallel_config (`ContextParallelConfig`, *optional*):
+ Configuration for context parallelism.
+ """
+
+ context_parallel_config: Optional[ContextParallelConfig] = None
+
+ _rank: int = None
+ _world_size: int = None
+ _device: torch.device = None
+ _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+
+ def setup(
+ self,
+ rank: int,
+ world_size: int,
+ device: torch.device,
+ *,
+ cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+ ):
+ self._rank = rank
+ self._world_size = world_size
+ self._device = device
+ self._cp_mesh = cp_mesh
+ if self.context_parallel_config is not None:
+ self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+
+
+@dataclass(frozen=True)
+class ContextParallelInput:
+ """
+ Configuration for splitting an input tensor across context parallel region.
+
+ Args:
+ split_dim (`int`):
+ The dimension along which to split the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before splitting.
+ split_output (`bool`, *optional*, defaults to `False`):
+ Whether to split the output tensor of the layer along the given `split_dim` instead of the input tensor.
+ This is useful for layers whose outputs should be split after it does some preprocessing on the inputs (ex:
+ RoPE).
+ """
+
+ split_dim: int
+ expected_dims: Optional[int] = None
+ split_output: bool = False
+
+ def __repr__(self):
+ return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"
+
+
+@dataclass(frozen=True)
+class ContextParallelOutput:
+ """
+ Configuration for gathering an output tensor across context parallel region.
+
+ Args:
+ gather_dim (`int`):
+ The dimension along which to gather the tensor.
+ expected_dims (`int`, *optional*):
+ The expected number of dimensions of the tensor. If provided, a check will be performed to ensure that the
+ tensor has the expected number of dimensions before gathering.
+ """
+
+ gather_dim: int
+ expected_dims: Optional[int] = None
+
+ def __repr__(self):
+ return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"
+
+
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+ContextParallelInputType = Dict[
+ Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+ContextParallelOutputType = Union[
+ ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+# "": {
+# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+# },
+# "pos_embed": {
+# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# },
+# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+# the actual forward logic of the QwenImageTransformer2DModel is run, we will splitthe inputs)
+# - "pos_embed":
+# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (imag & text freqs),
+# we can individually specify how they should be split
+# - "proj_out":
+# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+# layer forward has run).
+#
+# ContextParallelInput:
+# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index c720b37955..5164cf311d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -111,11 +111,7 @@ class AttentionMixin:
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
for module in self.modules():
if isinstance(module, AttentionModuleMixin):
@@ -241,7 +237,7 @@ class AttentionModuleMixin:
op_fw, op_bw = attention_op
dtype, *_ = op_fw.SUPPORTED_DTYPES
q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
- _ = xops.memory_efficient_attention(q, q, q)
+ _ = xops.ops.memory_efficient_attention(q, q, q)
except Exception as e:
raise e
@@ -674,7 +670,7 @@ class JointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index f71be7c8ec..c17a3d0ed6 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -17,12 +17,18 @@ import functools
import inspect
import math
from enum import Enum
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
import torch
+
+if torch.distributed.is_available():
+ import torch.distributed._functional_collectives as funcol
+
from ..utils import (
get_logger,
+ is_aiter_available,
+ is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
@@ -39,7 +45,11 @@ from ..utils import (
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
+if TYPE_CHECKING:
+ from ._modeling_parallel import ParallelConfig
+
_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
@@ -47,6 +57,7 @@ _REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
@@ -56,9 +67,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
+ from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
+ _wrapped_flash_attn_backward = None
+ _wrapped_flash_attn_forward = None
if _CAN_USE_FLASH_ATTN_3:
@@ -68,6 +82,12 @@ else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
+
+if _CAN_USE_AITER_ATTN:
+ from aiter import flash_attn_func as aiter_flash_attn_func
+else:
+ aiter_flash_attn_func = None
+
if DIFFUSERS_ENABLE_HUB_KERNELS:
if not is_kernels_available():
raise ImportError(
@@ -168,6 +188,9 @@ class AttentionBackendName(str, Enum):
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
+ # `aiter`
+ AITER = "aiter"
+
# PyTorch native
FLEX = "flex"
NATIVE = "native"
@@ -197,17 +220,24 @@ class _AttentionBackendRegistry:
_backends = {}
_constraints = {}
_supported_arg_names = {}
+ _supports_context_parallel = {}
_active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
_checks_enabled = DIFFUSERS_ATTN_CHECKS
@classmethod
- def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+ def register(
+ cls,
+ backend: AttentionBackendName,
+ constraints: Optional[List[Callable]] = None,
+ supports_context_parallel: bool = False,
+ ):
logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
def decorator(func):
cls._backends[backend] = func
cls._constraints[backend] = constraints or []
cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ cls._supports_context_parallel[backend] = supports_context_parallel
return func
return decorator
@@ -220,6 +250,17 @@ class _AttentionBackendRegistry:
def list_backends(cls):
return list(cls._backends.keys())
+ @classmethod
+ def _is_context_parallel_enabled(
+ cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+ ) -> bool:
+ supports_context_parallel = backend in cls._supports_context_parallel
+ is_degree_greater_than_1 = parallel_config is not None and (
+ parallel_config.context_parallel_config.ring_degree > 1
+ or parallel_config.context_parallel_config.ulysses_degree > 1
+ )
+ return supports_context_parallel and is_degree_greater_than_1
+
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
@@ -253,6 +294,7 @@ def dispatch_attention_fn(
attention_kwargs: Optional[Dict[str, Any]] = None,
*,
backend: Optional[AttentionBackendName] = None,
+ parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
attention_kwargs = attention_kwargs or {}
@@ -264,6 +306,14 @@ def dispatch_attention_fn(
backend_name = AttentionBackendName(backend)
backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+ if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
+ backend_name, parallel_config
+ ):
+ raise ValueError(
+ f"Backend {backend_name} either does not support context parallelism or context parallelism "
+ f"was enabled with a world size of 1."
+ )
+
kwargs = {
"query": query,
"key": key,
@@ -273,6 +323,7 @@ def dispatch_attention_fn(
"is_causal": is_causal,
"scale": scale,
**attention_kwargs,
+ "_parallel_config": parallel_config,
}
if is_torch_version(">=", "2.5.0"):
kwargs["enable_gqa"] = enable_gqa
@@ -376,6 +427,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
+ elif backend == AttentionBackendName.AITER:
+ if not _CAN_USE_AITER_ATTN:
+ raise RuntimeError(
+ f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
+ )
+
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -521,22 +578,701 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
-
-
-@_custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
-def _wrapped_flash_attn_3_original(
- query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
- out, lse = flash_attn_3_func(query, key, value)
+ # Hardcoded for now because pytorch does not support tuple/int type hints
+ window_size = (-1, -1)
+ out, lse, *_ = flash_attn_3_func(
+ q=q,
+ k=k,
+ v=v,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ qv=qv,
+ q_descale=q_descale,
+ k_descale=k_descale,
+ v_descale=v_descale,
+ window_size=window_size,
+ attention_chunk=attention_chunk,
+ softcap=softcap,
+ num_splits=num_splits,
+ pack_gqa=pack_gqa,
+ deterministic=deterministic,
+ sm_margin=sm_margin,
+ )
lse = lse.permute(0, 2, 1)
return out, lse
-@_register_fake("flash_attn_3::_flash_attn_forward")
-def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- batch_size, seq_len, num_heads, head_dim = query.shape
+@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
+def _(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ softmax_scale: Optional[float] = None,
+ causal: bool = False,
+ qv: Optional[torch.Tensor] = None,
+ q_descale: Optional[torch.Tensor] = None,
+ k_descale: Optional[torch.Tensor] = None,
+ v_descale: Optional[torch.Tensor] = None,
+ attention_chunk: int = 0,
+ softcap: float = 0.0,
+ num_splits: int = 1,
+ pack_gqa: Optional[bool] = None,
+ deterministic: bool = False,
+ sm_margin: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ window_size = (-1, -1) # noqa: F841
+ # A lot of the parameters here are not yet used in any way within diffusers.
+ # We can safely ignore for now and keep the fake op shape propagation simple.
+ batch_size, seq_len, num_heads, head_dim = q.shape
lse_shape = (batch_size, seq_len, num_heads)
- return torch.empty_like(query), query.new_empty(lse_shape)
+ return torch.empty_like(q), q.new_empty(lse_shape)
+
+
+# ===== Helper functions to use attention backends with templated CP autograd functions =====
+
+
+def _native_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ # Native attention does not return_lse
+ if return_lse:
+ raise ValueError("Native attention does not support return_lse=True")
+
+ # used for backward pass
+ if _save_ctx:
+ ctx.save_for_backward(query, key, value)
+ ctx.attn_mask = attn_mask
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.enable_gqa = enable_gqa
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+
+ return out
+
+
+def _native_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value = ctx.saved_tensors
+
+ query.requires_grad_(True)
+ key.requires_grad_(True)
+ value.requires_grad_(True)
+
+ query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query_t,
+ key=key_t,
+ value=value_t,
+ attn_mask=ctx.attn_mask,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ enable_gqa=ctx.enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+
+ grad_out_t = grad_out.permute(0, 2, 1, 3)
+ grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
+ outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
+ )
+
+ grad_query = grad_query_t.permute(0, 2, 1, 3)
+ grad_key = grad_key_t.permute(0, 2, 1, 3)
+ grad_value = grad_value_t.permute(0, 2, 1, 3)
+
+ return grad_query, grad_key, grad_value
+
+
+# https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958
+# forward declaration:
+# aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+def _cudnn_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")
+
+ tensors_to_save = ()
+
+ # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
+ # if the input tensors are not contiguous.
+ query = query.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+ tensors_to_save += (query, key, value)
+
+ out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
+ torch.ops.aten._scaled_dot_product_cudnn_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_bias=attn_mask,
+ compute_log_sumexp=return_lse,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ return_debug_mask=False,
+ scale=scale,
+ )
+ )
+
+ tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
+ if _save_ctx:
+ ctx.save_for_backward(*tensors_to_save)
+ ctx.dropout_p = dropout_p
+ ctx.is_causal = is_causal
+ ctx.scale = scale
+ ctx.attn_mask = attn_mask
+ ctx.max_q = max_q
+ ctx.max_k = max_k
+
+ out = out.transpose(1, 2).contiguous()
+ if lse is not None:
+ lse = lse.transpose(1, 2).contiguous()
+ return (out, lse) if return_lse else out
+
+
+# backward declaration:
+# aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+def _cudnn_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
+
+ grad_out = grad_out.transpose(1, 2).contiguous()
+ key = key.transpose(1, 2).contiguous()
+ value = value.transpose(1, 2).contiguous()
+
+ # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
+ grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ logsumexp=lse,
+ philox_seed=philox_seed,
+ philox_offset=philox_offset,
+ attn_bias=ctx.attn_mask,
+ cum_seq_q=cum_seq_q,
+ cum_seq_k=cum_seq_k,
+ max_q=ctx.max_q,
+ max_k=ctx.max_k,
+ dropout_p=ctx.dropout_p,
+ is_causal=ctx.is_causal,
+ scale=ctx.scale,
+ )
+ grad_query, grad_key, grad_value = (x.transpose(1, 2).contiguous() for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value
+
+
+# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
+def _flash_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
+
+ # Hardcoded for now
+ window_size = (-1, -1)
+ softcap = 0.0
+ alibi_slopes = None
+ deterministic = False
+ grad_enabled = any(x.requires_grad for x in (query, key, value))
+
+ if scale is None:
+ scale = query.shape[-1] ** (-0.5)
+
+ # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
+ if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
+ dropout_p = dropout_p if dropout_p > 0 else 1e-30
+
+ with torch.set_grad_enabled(grad_enabled):
+ out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
+ query,
+ key,
+ value,
+ dropout_p,
+ scale,
+ is_causal,
+ window_size[0],
+ window_size[1],
+ softcap,
+ alibi_slopes,
+ return_lse,
+ )
+ lse = lse.permute(0, 2, 1)
+
+ if _save_ctx:
+ ctx.save_for_backward(query, key, value, out, lse, rng_state)
+ ctx.dropout_p = dropout_p
+ ctx.scale = scale
+ ctx.is_causal = is_causal
+ ctx.window_size = window_size
+ ctx.softcap = softcap
+ ctx.alibi_slopes = alibi_slopes
+ ctx.deterministic = deterministic
+
+ return (out, lse) if return_lse else out
+
+
+def _flash_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+):
+ query, key, value, out, lse, rng_state = ctx.saved_tensors
+ grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+ lse_d = _wrapped_flash_attn_backward( # noqa: F841
+ grad_out,
+ query,
+ key,
+ value,
+ out,
+ lse,
+ grad_query,
+ grad_key,
+ grad_value,
+ ctx.dropout_p,
+ ctx.scale,
+ ctx.is_causal,
+ ctx.window_size[0],
+ ctx.window_size[1],
+ ctx.softcap,
+ ctx.alibi_slopes,
+ ctx.deterministic,
+ rng_state,
+ )
+
+ # Head dimension may have been padded
+ grad_query = grad_query[..., : grad_out.shape[-1]]
+ grad_key = grad_key[..., : grad_out.shape[-1]]
+ grad_value = grad_value[..., : grad_out.shape[-1]]
+
+ return grad_query, grad_key, grad_value
+
+
+def _sage_attention_forward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _save_ctx: bool = True,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("`attn_mask` is not yet supported for Sage attention.")
+ if dropout_p > 0.0:
+ raise ValueError("`dropout_p` is not yet supported for Sage attention.")
+ if enable_gqa:
+ raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
+
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ lse = None
+ if return_lse:
+ out, lse, *_ = out
+ lse = lse.permute(0, 2, 1)
+
+ return (out, lse) if return_lse else out
+
+
+def _sage_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+):
+ raise NotImplementedError("Backward pass is not implemented for Sage attention.")
+
+
+# ===== Context parallel =====
+
+
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor):
+ if isinstance(tensor, funcol.AsyncCollectiveTensor):
+ tensor = tensor.wait()
+ return tensor
+
+
+def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
+ shape = x.shape
+ # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
+ # to benchmark triton codegen fails somewhere:
+ # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
+ # ValueError: Tensors must be contiguous
+ x = x.flatten()
+ x = funcol.all_to_all_single(x, None, None, group)
+ x = x.reshape(shape)
+ x = _wait_tensor(x)
+ return x
+
+
+class TemplatedRingAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ring_mesh = _parallel_config.context_parallel_config._ring_mesh
+ rank = _parallel_config.context_parallel_config._ring_local_rank
+ world_size = _parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ prev_out = prev_lse = None
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx.q_shape = query.shape
+ ctx.kv_shape = key.shape
+ ctx._parallel_config = _parallel_config
+
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ out, lse = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ True,
+ _save_ctx=i == 0,
+ _parallel_config=_parallel_config,
+ )
+
+ if _parallel_config.context_parallel_config.convert_to_fp32:
+ out = out.to(torch.float32)
+ lse = lse.to(torch.float32)
+
+ lse = lse.unsqueeze(-1)
+ if prev_out is not None:
+ out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
+ lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
+ prev_out = out
+ prev_lse = lse
+
+ out = out.to(query.dtype)
+ lse = lse.squeeze(-1)
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ring_mesh = ctx._parallel_config.context_parallel_config._ring_mesh
+ rank = ctx._parallel_config.context_parallel_config._ring_local_rank
+ world_size = ctx._parallel_config.context_parallel_config.ring_degree
+ next_rank = (rank + 1) % world_size
+ next_ranks = list(range(1, world_size)) + [0]
+
+ accum_dtype = torch.float32 if ctx._parallel_config.context_parallel_config.convert_to_fp32 else grad_out.dtype
+ grad_query = torch.zeros(ctx.q_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_key = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ grad_value = torch.zeros(ctx.kv_shape, dtype=accum_dtype, device=grad_out.device)
+ next_grad_kv = None
+
+ query, key, value, *_ = ctx.saved_tensors
+ kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
+ kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
+ kv_buffer = kv_buffer.chunk(world_size)
+
+ for i in range(world_size):
+ if i > 0:
+ kv = kv_buffer[next_rank]
+ key_numel = key.numel()
+ key = kv[:key_numel].reshape_as(key)
+ value = kv[key_numel:].reshape_as(value)
+ next_rank = (next_rank + 1) % world_size
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ if i > 0:
+ grad_kv_buffer = _wait_tensor(next_grad_kv)
+ grad_key_numel = grad_key.numel()
+ grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
+ grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)
+
+ grad_query += grad_query_op
+ grad_key += grad_key_op
+ grad_value += grad_value_op
+
+ if i < world_size - 1:
+ grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
+ next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())
+
+ grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+class TemplatedUlyssesAttention(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx: torch.autograd.function.FunctionCtx,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor],
+ dropout_p: float,
+ is_causal: bool,
+ scale: Optional[float],
+ enable_gqa: bool,
+ return_lse: bool,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+ ):
+ ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
+ world_size = _parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ ctx.forward_op = forward_op
+ ctx.backward_op = backward_op
+ ctx._parallel_config = _parallel_config
+
+ B, S_Q_LOCAL, H, D = query.shape
+ _, S_KV_LOCAL, _, _ = key.shape
+ H_LOCAL = H // world_size
+ query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
+ query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
+
+ out = forward_op(
+ ctx,
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ _save_ctx=True,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ out = _all_to_all_single(out, group)
+ out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
+
+ if return_lse:
+ lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
+ lse = _all_to_all_single(lse, group)
+ lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
+ else:
+ lse = None
+
+ return (out, lse) if return_lse else out
+
+ @staticmethod
+ def backward(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ ):
+ ulysses_mesh = ctx._parallel_config.context_parallel_config._ulysses_mesh
+ world_size = ctx._parallel_config.context_parallel_config.ulysses_degree
+ group = ulysses_mesh.get_group()
+
+ B, S_LOCAL, H, D = grad_out.shape
+ H_LOCAL = H // world_size
+
+ grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+ grad_out = _all_to_all_single(grad_out, group)
+ grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+ grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx, grad_out)
+
+ grad_query, grad_key, grad_value = (
+ x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+ for x in (grad_query_op, grad_key_op, grad_value_op)
+ )
+ grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
+ grad_query, grad_key, grad_value = (
+ x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+ )
+
+ return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
+
+
+def _templated_context_parallel_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ *,
+ forward_op,
+ backward_op,
+ _parallel_config: Optional["ParallelConfig"] = None,
+):
+ if attn_mask is not None:
+ raise ValueError("Attention mask is not yet supported for templated attention.")
+ if is_causal:
+ raise ValueError("Causal attention is not yet supported for templated attention.")
+ if enable_gqa:
+ raise ValueError("GQA is not yet supported for templated attention.")
+
+ # TODO: add support for unified attention with ring/ulysses degree both being > 1
+ if _parallel_config.context_parallel_config.ring_degree > 1:
+ return TemplatedRingAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ elif _parallel_config.context_parallel_config.ulysses_degree > 1:
+ return TemplatedUlyssesAttention.apply(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op,
+ backward_op,
+ _parallel_config,
+ )
+ else:
+ raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
# ===== Attention backends =====
@@ -545,34 +1281,50 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
dropout_p: float = 0.0,
- scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- alibi_slopes: Optional[torch.Tensor] = None,
- deterministic: bool = False,
- return_attn_probs: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- out = flash_attn_func(
- q=query,
- k=key,
- v=value,
- dropout_p=dropout_p,
- softmax_scale=scale,
- causal=is_causal,
- window_size=window_size,
- softcap=softcap,
- alibi_slopes=alibi_slopes,
- deterministic=deterministic,
- return_attn_probs=return_attn_probs,
- )
- return out
+ lse = None
+ if _parallel_config is None:
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_attn_probs=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ dropout_p,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_flash_attention_forward_op,
+ backward_op=_flash_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -583,19 +1335,12 @@ def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- alibi_slopes: Optional[torch.Tensor] = None,
- deterministic: bool = False,
- return_attn_probs: bool = False,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -603,16 +1348,11 @@ def _flash_varlen_attention(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -635,11 +1375,7 @@ def _flash_varlen_attention(
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
- window_size=window_size,
- softcap=softcap,
- alibi_slopes=alibi_slopes,
- deterministic=deterministic,
- return_attn_probs=return_attn_probs,
+ return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
@@ -656,30 +1392,17 @@ def _flash_attention_3(
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- deterministic: bool = False,
- return_attn_probs: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- out, lse, *_ = flash_attn_3_func(
+ out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
- qv=None,
- q_descale=None,
- k_descale=None,
- v_descale=None,
- window_size=window_size,
- attention_chunk=0,
- softcap=softcap,
- num_splits=1,
- pack_gqa=None,
- deterministic=deterministic,
- sm_margin=0,
)
- return (out, lse) if return_attn_probs else out
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -696,6 +1419,7 @@ def _flash_attention_3_hub(
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
out = flash_attn_3_func_hub(
q=query,
@@ -728,17 +1452,11 @@ def _flash_varlen_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
- window_size: Tuple[int, int] = (-1, -1),
- softcap: float = 0.0,
- deterministic: bool = False,
- return_attn_probs: bool = False,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
@@ -746,16 +1464,11 @@ def _flash_varlen_attention_3(
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -775,24 +1488,53 @@ def _flash_varlen_attention_3(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
- seqused_q=None,
- seqused_k=None,
softmax_scale=scale,
causal=is_causal,
- qv=None,
- q_descale=None,
- k_descale=None,
- v_descale=None,
- window_size=window_size,
- softcap=softcap,
- num_splits=1,
- pack_gqa=None,
- deterministic=deterministic,
- sm_margin=0,
)
out = out.unflatten(0, (batch_size, -1))
- return (out, lse) if return_attn_probs else out
+ return (out, lse) if return_lse else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.AITER,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _aiter_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ if not return_lse and torch.is_grad_enabled():
+ # aiter requires return_lse=True by assertion when gradients are enabled.
+ out, lse, *_ = aiter_flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_lse=True,
+ )
+ else:
+ out = aiter_flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -808,7 +1550,7 @@ def _native_flex_attention(
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
- kernel_options: Optional[Dict[str, Any]] = None,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
# TODO: should we LRU cache the block mask creation?
score_mod = None
@@ -853,7 +1595,6 @@ def _native_flex_attention(
scale=scale,
enable_gqa=enable_gqa,
return_lse=return_lse,
- kernel_options=kernel_options,
)
out = out.permute(0, 2, 1, 3)
return out
@@ -862,6 +1603,7 @@ def _native_flex_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
+ supports_context_parallel=True,
)
def _native_attention(
query: torch.Tensor,
@@ -872,38 +1614,13 @@ def _native_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- out = torch.nn.functional.scaled_dot_product_attention(
- query=query,
- key=key,
- value=value,
- attn_mask=attn_mask,
- dropout_p=dropout_p,
- is_causal=is_causal,
- scale=scale,
- enable_gqa=enable_gqa,
- )
- out = out.permute(0, 2, 1, 3)
- return out
-
-
-@_AttentionBackendRegistry.register(
- AttentionBackendName._NATIVE_CUDNN,
- constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
-)
-def _native_cudnn_attention(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: Optional[float] = None,
- enable_gqa: bool = False,
-) -> torch.Tensor:
- query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ if return_lse:
+ raise ValueError("Native attention backend does not support setting `return_lse=True`.")
+ if _parallel_config is None:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
@@ -914,10 +1631,79 @@ def _native_cudnn_attention(
scale=scale,
enable_gqa=enable_gqa,
)
- out = out.permute(0, 2, 1, 3)
+ out = out.permute(0, 2, 1, 3)
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op=_native_attention_forward_op,
+ backward_op=_native_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+
return out
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
+)
+def _native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+ lse = None
+ if _parallel_config is None and not return_lse:
+ query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ attn_mask,
+ dropout_p,
+ is_causal,
+ scale,
+ enable_gqa,
+ return_lse,
+ forward_op=_cudnn_attention_forward_op,
+ backward_op=_cudnn_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
+
+
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_EFFICIENT,
constraints=[_check_device, _check_shape],
@@ -931,7 +1717,11 @@ def _native_efficient_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native efficient attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -960,7 +1750,11 @@ def _native_flash_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native flash attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -990,7 +1784,11 @@ def _native_math_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Native math attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = torch.nn.functional.scaled_dot_product_attention(
@@ -1017,7 +1815,11 @@ def _native_npu_attention(
value: torch.Tensor,
dropout_p: float = 0.0,
scale: Optional[float] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
@@ -1047,7 +1849,11 @@ def _native_xla_attention(
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
query = query / math.sqrt(query.shape[-1])
out = xla_flash_attention(
@@ -1063,6 +1869,7 @@ def _native_xla_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+ supports_context_parallel=True,
)
def _sage_attention(
query: torch.Tensor,
@@ -1071,16 +1878,40 @@ def _sage_attention(
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
- return sageattn(
- q=query,
- k=key,
- v=value,
- tensor_layout="NHD",
- is_causal=is_causal,
- sm_scale=scale,
- return_lse=return_lse,
- )
+ lse = None
+ if _parallel_config is None:
+ out = sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+ if return_lse:
+ out, lse, *_ = out
+ else:
+ out = _templated_context_parallel_attention(
+ query,
+ key,
+ value,
+ None,
+ 0.0,
+ is_causal,
+ scale,
+ False,
+ return_lse,
+ forward_op=_sage_attention_forward_op,
+ backward_op=_sage_attention_backward_op,
+ _parallel_config=_parallel_config,
+ )
+ if return_lse:
+ out, lse = out
+
+ return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
@@ -1091,31 +1922,26 @@ def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
- cu_seqlens_q: Optional[torch.Tensor] = None,
- cu_seqlens_k: Optional[torch.Tensor] = None,
- max_seqlen_q: Optional[int] = None,
- max_seqlen_k: Optional[int] = None,
+ attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
- smooth_k: bool = True,
- attn_mask: Optional[torch.Tensor] = None,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
+
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
- if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
- _prepare_for_flash_attn_or_sage_varlen(
- batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
- )
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
- else:
- seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
- cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
- cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+ )
key_valid, value_valid = [], []
for b in range(batch_size):
@@ -1137,7 +1963,6 @@ def _sage_varlen_attention(
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
- smooth_k=smooth_k,
)
out = out.unflatten(0, (batch_size, -1))
@@ -1154,11 +1979,8 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
- smooth_k: bool = True,
- smooth_v: bool = False,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
@@ -1166,11 +1988,7 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
- smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1185,10 +2003,8 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
- smooth_k: bool = True,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
@@ -1196,10 +2012,7 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1214,11 +2027,8 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
- pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
- smooth_k: bool = True,
- smooth_v: bool = False,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
@@ -1226,11 +2036,7 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
v=value,
tensor_layout="NHD",
is_causal=is_causal,
- qk_quant_gran=qk_quant_gran,
sm_scale=scale,
- pv_accum_dtype=pv_accum_dtype,
- smooth_k=smooth_k,
- smooth_v=smooth_v,
return_lse=return_lse,
)
@@ -1245,19 +2051,16 @@ def _sage_qk_int8_pv_fp16_triton_attention(
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
- quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
- smooth_k: bool = True,
return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
v=value,
tensor_layout="NHD",
- quantization_backend=quantization_backend,
is_causal=is_causal,
sm_scale=scale,
- smooth_k=smooth_k,
return_lse=return_lse,
)
@@ -1275,7 +2078,12 @@ def _xformers_attention(
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
+ return_lse: bool = False,
+ _parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
+ if return_lse:
+ raise ValueError("xformers attention backend does not support setting `return_lse=True`.")
+
batch_size, seq_len_q, num_heads_q, _ = query.shape
_, seq_len_kv, num_heads_kv, _ = key.shape
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 990245de17..66455d733a 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -3669,11 +3669,7 @@ class FusedAttnProcessor2_0:
fused projection layers. For self-attention modules, all projection matrices (i.e., query, key, value) are fused.
For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is currently 🧪 experimental in nature and can change in future.
-
-
+ > [!WARNING] > This API is currently 🧪 experimental in nature and can change in future.
"""
def __init__(self):
diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
index bfe386f1f6..c96b4fa88c 100644
--- a/src/diffusers/models/auto_model.py
+++ b/src/diffusers/models/auto_model.py
@@ -19,6 +19,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import logging
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
logger = logging.get_logger(__name__)
@@ -114,48 +115,45 @@ class AutoModel(ConfigMixin):
disable_mmap ('bool', *optional*, defaults to 'False'):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+ trust_remote_cocde (`bool`, *optional*, defaults to `False`):
+ Whether to trust remote code
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
Example:
```py
from diffusers import AutoModel
- unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ unet = AutoModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
"""
- cache_dir = kwargs.pop("cache_dir", None)
- force_download = kwargs.pop("force_download", False)
- proxies = kwargs.pop("proxies", None)
- token = kwargs.pop("token", None)
- local_files_only = kwargs.pop("local_files_only", False)
- revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
+ trust_remote_code = kwargs.pop("trust_remote_code", False)
- load_config_kwargs = {
- "cache_dir": cache_dir,
- "force_download": force_download,
- "proxies": proxies,
- "token": token,
- "local_files_only": local_files_only,
- "revision": revision,
- }
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "proxies",
+ "revision",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
+
+ # load_config_kwargs uses the same hub kwargs minus subfolder and resume_download
+ load_config_kwargs = {k: v for k, v in hub_kwargs.items() if k not in ["subfolder"]}
library = None
orig_class_name = None
@@ -189,15 +187,34 @@ class AutoModel(ConfigMixin):
else:
raise ValueError(f"Couldn't find model associated with the config file at {pretrained_model_or_path}.")
- from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+ has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+ trust_remote_code = resolve_trust_remote_code(trust_remote_code, pretrained_model_or_path, has_remote_code)
+ if not has_remote_code and trust_remote_code:
+ raise ValueError(
+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
+ )
- model_cls, _ = get_class_obj_and_candidates(
- library_name=library,
- class_name=orig_class_name,
- importable_classes=ALL_IMPORTABLE_CLASSES,
- pipelines=None,
- is_pipeline_module=False,
- )
+ if has_remote_code and trust_remote_code:
+ class_ref = config["auto_map"][cls.__name__]
+ module_file, class_name = class_ref.split(".")
+ module_file = module_file + ".py"
+ model_cls = get_class_from_dynamic_module(
+ pretrained_model_or_path,
+ subfolder=subfolder,
+ module_file=module_file,
+ class_name=class_name,
+ **hub_kwargs,
+ )
+ else:
+ from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
+
+ model_cls, _ = get_class_obj_and_candidates(
+ library_name=library,
+ class_name=orig_class_name,
+ importable_classes=ALL_IMPORTABLE_CLASSES,
+ pipelines=None,
+ is_pipeline_module=False,
+ )
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py
index c008a45298..edfaabb070 100644
--- a/src/diffusers/models/autoencoders/__init__.py
+++ b/src/diffusers/models/autoencoders/__init__.py
@@ -5,6 +5,8 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_cosmos import AutoencoderKLCosmos
from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
+from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
+from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
diff --git a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
index 54b1fc6771..fa49fcfe79 100644
--- a/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_asym_kl.py
@@ -20,10 +20,10 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder, MaskConditionDecoder
-class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
+class AsymmetricAutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Designing a Better Asymmetric VQGAN for StableDiffusion https://huggingface.co/papers/2306.04632 . A VAE model with
KL loss for encoding images into latents and decoding latent representations into images.
@@ -107,9 +107,6 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
- self.use_slicing = False
- self.use_tiling = False
-
self.register_to_config(block_out_channels=up_block_out_channels)
self.register_to_config(force_upcast=False)
diff --git a/src/diffusers/models/autoencoders/autoencoder_dc.py b/src/diffusers/models/autoencoders/autoencoder_dc.py
index d3f31de854..724ec3bb76 100644
--- a/src/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/src/diffusers/models/autoencoders/autoencoder_dc.py
@@ -27,7 +27,7 @@ from ..attention_processor import SanaMultiscaleLinearAttention
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm, get_normalization
from ..transformers.sana_transformer import GLUMBConv
-from .vae import DecoderOutput, EncoderOutput
+from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
class ResBlock(nn.Module):
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
return hidden_states
-class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderDC(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
An Autoencoder model introduced in [DCAE](https://huggingface.co/papers/2410.10733) and used in
[SANA](https://huggingface.co/papers/2410.10629).
@@ -536,27 +536,6 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
- def disable_tiling(self) -> None:
- r"""
- Disable tiled AE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced AE decoding. When this option is enabled, the AE will split the input tensor in slices to compute
- decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced AE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
@@ -617,7 +596,7 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
returned.
"""
if self.use_slicing and z.size(0) > 1:
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl.py b/src/diffusers/models/autoencoders/autoencoder_kl.py
index 9a4375a36b..1a72aa3cfe 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl.py
@@ -32,10 +32,10 @@ from ..attention_processor import (
)
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
-class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
+class AutoencoderKL(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
@@ -138,35 +138,6 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
- def enable_tiling(self, use_tiling: bool = True):
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- def disable_tiling(self):
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -532,11 +503,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -556,11 +523,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapter
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
index c24b8f42ac..6756586460 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py
@@ -28,6 +28,7 @@ from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..resnet import ResnetBlock2D
from ..upsampling import Upsample2D
+from .vae import AutoencoderMixin
class AllegroTemporalConvLayer(nn.Module):
@@ -673,7 +674,7 @@ class AllegroDecoder3D(nn.Module):
return sample
-class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
+class AutoencoderKLAllegro(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[Allegro](https://github.com/rhymes-ai/Allegro).
@@ -795,35 +796,6 @@ class AutoencoderKLAllegro(ModelMixin, ConfigMixin):
sample_size - self.tile_overlap_w,
)
- def enable_tiling(self) -> None:
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = True
-
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
# TODO(aryan)
# if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index e0e9436e89..5096b725d0 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -29,7 +29,7 @@ from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -955,7 +955,7 @@ class CogVideoXDecoder3D(nn.Module):
return hidden_states, new_conv_cache
-class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLCogVideoX(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[CogVideoX](https://github.com/THUDM/CogVideo).
@@ -1124,27 +1124,6 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
index 500e316ebc..b17522d1c4 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -24,7 +24,7 @@ from ...utils import get_logger
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, IdentityDistribution
+from .vae import AutoencoderMixin, DecoderOutput, IdentityDistribution
logger = get_logger(__name__)
@@ -875,7 +875,7 @@ class CosmosDecoder3d(nn.Module):
return hidden_states
-class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
+class AutoencoderKLCosmos(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Autoencoder used in [Cosmos](https://huggingface.co/papers/2501.03575).
@@ -1031,27 +1031,6 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
x = self.encoder(x)
enc = self.quant_conv(x)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
index 7b0f9889a5..88b9bb507f 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -18,7 +18,6 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
@@ -27,7 +26,7 @@ from ..activations import get_activation
from ..attention_processor import Attention
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +624,7 @@ class HunyuanVideoDecoder3D(nn.Module):
return hidden_states
-class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
+class AutoencoderKLHunyuanVideo(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
@@ -764,27 +763,6 @@ class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
new file mode 100644
index 0000000000..616d0d4158
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py
@@ -0,0 +1,709 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageResnetBlock(nn.Module):
+ r"""
+ Residual block with two convolutions and optional channel change.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.nonlinearity = get_activation(non_linearity)
+
+ # layers
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+ if in_channels != out_channels:
+ self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+ else:
+ self.conv_shortcut = None
+
+ def forward(self, x):
+ # Apply shortcut connection
+ residual = x
+
+ # First normalization and activation
+ x = self.norm1(x)
+ x = self.nonlinearity(x)
+
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = self.nonlinearity(x)
+ x = self.conv2(x)
+
+ if self.conv_shortcut is not None:
+ x = self.conv_shortcut(x)
+ # Add residual connection
+ return x + residual
+
+
+class HunyuanImageAttentionBlock(nn.Module):
+ r"""
+ Self-attention with a single head.
+
+ Args:
+ in_channels (int): The number of channels in the input tensor.
+ """
+
+ def __init__(self, in_channels: int):
+ super().__init__()
+
+ # layers
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+ self.to_q = nn.Conv2d(in_channels, in_channels, 1)
+ self.to_k = nn.Conv2d(in_channels, in_channels, 1)
+ self.to_v = nn.Conv2d(in_channels, in_channels, 1)
+ self.proj = nn.Conv2d(in_channels, in_channels, 1)
+
+ def forward(self, x):
+ identity = x
+ x = self.norm(x)
+
+ # compute query, key, value
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ batch_size, channels, height, width = query.shape
+ query = query.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+ key = key.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+ value = value.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels).contiguous()
+
+ # apply attention
+ x = F.scaled_dot_product_attention(query, key, value)
+
+ x = x.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)
+ # output projection
+ x = self.proj(x)
+
+ return x + identity
+
+
+class HunyuanImageDownsample(nn.Module):
+ """
+ Downsampling block for spatial reduction.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ factor = 4
+ if out_channels % factor != 0:
+ raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")
+
+ self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+ self.group_size = factor * in_channels // out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv(x)
+
+ B, C, H, W = h.shape
+ h = h.reshape(B, C, H // 2, 2, W // 2, 2)
+ h = h.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
+ h = h.reshape(B, 4 * C, H // 2, W // 2)
+
+ B, C, H, W = x.shape
+ shortcut = x.reshape(B, C, H // 2, 2, W // 2, 2)
+ shortcut = shortcut.permute(0, 3, 5, 1, 2, 4) # b, r1, r2, c, h, w
+ shortcut = shortcut.reshape(B, 4 * C, H // 2, W // 2)
+
+ B, C, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, H, W).mean(dim=2)
+ return h + shortcut
+
+
+class HunyuanImageUpsample(nn.Module):
+ """
+ Upsampling block for spatial expansion.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int):
+ super().__init__()
+ factor = 4
+ self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
+ self.repeats = factor * out_channels // in_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv(x)
+
+ B, C, H, W = h.shape
+ h = h.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
+ h = h.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
+ h = h.reshape(B, C // 4, H * 2, W * 2)
+
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
+
+ B, C, H, W = shortcut.shape
+ shortcut = shortcut.reshape(B, 2, 2, C // 4, H, W) # b, r1, r2, c, h, w
+ shortcut = shortcut.permute(0, 3, 4, 1, 5, 2) # b, c, h, r1, w, r2
+ shortcut = shortcut.reshape(B, C // 4, H * 2, W * 2)
+ return h + shortcut
+
+
+class HunyuanImageMidBlock(nn.Module):
+ """
+ Middle block for HunyuanImageVAE encoder and decoder.
+
+ Args:
+ in_channels (int): Number of input channels.
+ num_layers (int): Number of layers.
+ """
+
+ def __init__(self, in_channels: int, num_layers: int = 1):
+ super().__init__()
+
+ resnets = [HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels)]
+
+ attentions = []
+ for _ in range(num_layers):
+ attentions.append(HunyuanImageAttentionBlock(in_channels))
+ resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.resnets[0](x)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ x = attn(x)
+ x = resnet(x)
+
+ return x
+
+
+class HunyuanImageEncoder2D(nn.Module):
+ r"""
+ Encoder network that compresses input to latent representation.
+
+ Args:
+ in_channels (int): Number of input channels.
+ z_channels (int): Number of latent channels.
+ block_out_channels (list of int): Output channels for each block.
+ num_res_blocks (int): Number of residual blocks per block.
+ spatial_compression_ratio (int): Spatial downsampling factor.
+ non_linearity (str): Type of non-linearity to use. Default is "silu".
+ downsample_match_channel (bool): Whether to match channels during downsampling.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ z_channels: int,
+ block_out_channels: Tuple[int, ...],
+ num_res_blocks: int,
+ spatial_compression_ratio: int,
+ non_linearity: str = "silu",
+ downsample_match_channel: bool = True,
+ ):
+ super().__init__()
+ if block_out_channels[-1] % (2 * z_channels) != 0:
+ raise ValueError(
+ f"block_out_channels[-1 has to be divisible by 2 * out_channels, you have block_out_channels = {block_out_channels[-1]} and out_channels = {z_channels}"
+ )
+
+ self.in_channels = in_channels
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.spatial_compression_ratio = spatial_compression_ratio
+
+ self.group_size = block_out_channels[-1] // (2 * z_channels)
+ self.nonlinearity = get_activation(non_linearity)
+
+ # init block
+ self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ # downsample blocks
+ self.down_blocks = nn.ModuleList([])
+
+ block_in_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ block_out_channel = block_out_channels[i]
+ # residual blocks
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(
+ HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ # downsample block
+ if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
+ if downsample_match_channel:
+ block_out_channel = block_out_channels[i + 1]
+ self.down_blocks.append(
+ HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ # middle blocks
+ self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)
+
+ # output blocks
+ # Output layers
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv_in(x)
+
+ ## downsamples
+ for down_block in self.down_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(down_block, x)
+ else:
+ x = down_block(x)
+
+ ## middle
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ x = self._gradient_checkpointing_func(self.mid_block, x)
+ else:
+ x = self.mid_block(x)
+
+ ## head
+ B, C, H, W = x.shape
+ residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)
+
+ x = self.norm_out(x)
+ x = self.nonlinearity(x)
+ x = self.conv_out(x)
+ return x + residual
+
+
+class HunyuanImageDecoder2D(nn.Module):
+ r"""
+ Decoder network that reconstructs output from latent representation.
+
+ Args:
+ z_channels : int
+ Number of latent channels.
+ out_channels : int
+ Number of output channels.
+ block_out_channels : Tuple[int, ...]
+ Output channels for each block.
+ num_res_blocks : int
+ Number of residual blocks per block.
+ spatial_compression_ratio : int
+ Spatial upsampling factor.
+ upsample_match_channel : bool
+ Whether to match channels during upsampling.
+ non_linearity (str): Type of non-linearity to use. Default is "silu".
+ """
+
+ def __init__(
+ self,
+ z_channels: int,
+ out_channels: int,
+ block_out_channels: Tuple[int, ...],
+ num_res_blocks: int,
+ spatial_compression_ratio: int,
+ upsample_match_channel: bool = True,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ if block_out_channels[0] % z_channels != 0:
+ raise ValueError(
+ f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
+ )
+
+ self.z_channels = z_channels
+ self.block_out_channels = block_out_channels
+ self.num_res_blocks = num_res_blocks
+ self.repeat = block_out_channels[0] // z_channels
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.conv_in = nn.Conv2d(z_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
+
+ # Middle blocks with attention
+ self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[0], num_layers=1)
+
+ # Upsampling blocks
+ block_in_channel = block_out_channels[0]
+ self.up_blocks = nn.ModuleList()
+ for i in range(len(block_out_channels)):
+ block_out_channel = block_out_channels[i]
+ for _ in range(self.num_res_blocks + 1):
+ self.up_blocks.append(
+ HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
+ )
+ block_in_channel = block_out_channel
+
+ if i < np.log2(spatial_compression_ratio) and i != len(block_out_channels) - 1:
+ if upsample_match_channel:
+ block_out_channel = block_out_channels[i + 1]
+ self.up_blocks.append(HunyuanImageUpsample(block_in_channel, block_out_channel))
+ block_in_channel = block_out_channel
+
+ # Output layers
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
+ self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h = self._gradient_checkpointing_func(self.mid_block, h)
+ else:
+ h = self.mid_block(h)
+
+ for up_block in self.up_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ h = self._gradient_checkpointing_func(up_block, h)
+ else:
+ h = up_block(h)
+ h = self.norm_out(h)
+ h = self.nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+ r"""
+ A VAE model for 2D images with spatial tiling support.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = False
+
+ # fmt: off
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ latent_channels: int,
+ block_out_channels: Tuple[int, ...],
+ layers_per_block: int,
+ spatial_compression_ratio: int,
+ sample_size: int,
+ scaling_factor: float = None,
+ downsample_match_channel: bool = True,
+ upsample_match_channel: bool = True,
+ ) -> None:
+ # fmt: on
+ super().__init__()
+
+ self.encoder = HunyuanImageEncoder2D(
+ in_channels=in_channels,
+ z_channels=latent_channels,
+ block_out_channels=block_out_channels,
+ num_res_blocks=layers_per_block,
+ spatial_compression_ratio=spatial_compression_ratio,
+ downsample_match_channel=downsample_match_channel,
+ )
+
+ self.decoder = HunyuanImageDecoder2D(
+ z_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=list(reversed(block_out_channels)),
+ num_res_blocks=layers_per_block,
+ spatial_compression_ratio=spatial_compression_ratio,
+ upsample_match_channel=upsample_match_channel,
+ )
+
+ # Tiling and slicing configuration
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # Tiling parameters
+ self.tile_sample_min_size = sample_size
+ self.tile_latent_min_size = sample_size // spatial_compression_ratio
+ self.tile_overlap_factor = 0.25
+
+ def enable_tiling(
+ self,
+ tile_sample_min_size: Optional[int] = None,
+ tile_overlap_factor: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
+ to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+ allow processing larger images.
+
+ Args:
+ tile_sample_min_size (`int`, *optional*):
+ The minimum size required for a sample to be separated into tiles across the spatial dimension.
+ tile_overlap_factor (`float`, *optional*):
+ The overlap factor required for a latent to be separated into tiles across the spatial dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_size = tile_sample_min_size or self.tile_sample_min_size
+ self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
+ self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor):
+
+ batch_size, num_channels, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
+ return self.tiled_encode(x)
+
+ enc = self.encoder(x)
+
+ return enc
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
+
+ batch_size, num_channels, height, width = z.shape
+
+ if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+ return DecoderOutput(sample=decoded)
+
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Encode input using spatial tiling strategy.
+
+ Args:
+ x (`torch.Tensor`): Input tensor of shape (B, C, T, H, W).
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded images.
+ """
+ _, _, _, height, width = x.shape
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+ rows = []
+ for i in range(0, height, overlap_size):
+ row = []
+ for j in range(0, width, overlap_size):
+ tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ moments = torch.cat(result_rows, dim=-2)
+
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Decode latent using spatial tiling strategy.
+
+ Args:
+ z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ _, _, height, width = z.shape
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+ rows = []
+ for i in range(0, height, overlap_size):
+ row = []
+ for j in range(0, width, overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=-2)
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ """
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ posterior = self.encode(sample).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
new file mode 100644
index 0000000000..af40c7a6cb
--- /dev/null
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage_refiner.py
@@ -0,0 +1,934 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
+from ..activations import get_activation
+from ..modeling_outputs import AutoencoderKLOutput
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput, DiagonalGaussianDistribution
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageRefinerCausalConv3d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
+ stride: Union[int, Tuple[int, int, int]] = 1,
+ padding: Union[int, Tuple[int, int, int]] = 0,
+ dilation: Union[int, Tuple[int, int, int]] = 1,
+ bias: bool = True,
+ pad_mode: str = "replicate",
+ ) -> None:
+ super().__init__()
+
+ kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+ self.pad_mode = pad_mode
+ self.time_causal_padding = (
+ kernel_size[0] // 2,
+ kernel_size[0] // 2,
+ kernel_size[1] // 2,
+ kernel_size[1] // 2,
+ kernel_size[2] - 1,
+ 0,
+ )
+
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
+ return self.conv(hidden_states)
+
+
+class HunyuanImageRefinerRMS_norm(nn.Module):
+ r"""
+ A custom RMS normalization layer.
+
+ Args:
+ dim (int): The number of dimensions to normalize over.
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
+ Default is True.
+ images (bool, optional): Whether the input represents image data. Default is True.
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
+ """
+
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
+
+ def forward(self, x):
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
+
+
+class HunyuanImageRefinerAttnBlock(nn.Module):
+ def __init__(self, in_channels: int):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = HunyuanImageRefinerRMS_norm(in_channels, images=False)
+
+ self.to_q = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_k = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.to_v = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+ self.proj_out = nn.Conv3d(in_channels, in_channels, kernel_size=1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ identity = x
+
+ x = self.norm(x)
+
+ query = self.to_q(x)
+ key = self.to_k(x)
+ value = self.to_v(x)
+
+ batch_size, channels, frames, height, width = query.shape
+
+ query = query.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ key = key.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+ value = value.reshape(batch_size, channels, frames * height * width).permute(0, 2, 1).unsqueeze(1).contiguous()
+
+ x = nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None)
+
+ # batch_size, 1, frames * height * width, channels
+
+ x = x.squeeze(1).reshape(batch_size, frames, height, width, channels).permute(0, 4, 1, 2, 3)
+ x = self.proj_out(x)
+
+ return x + identity
+
+
+class HunyuanImageRefinerUpsampleDCAE(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
+ self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels * factor, kernel_size=3)
+
+ self.add_temporal_upsample = add_temporal_upsample
+ self.repeats = factor * out_channels // in_channels
+
+ @staticmethod
+ def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, r1*r2*r3*c, f, h, w) -> (b, c, r1*f, r2*h, r3*w)
+
+ Args:
+ tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
+ r1: temporal upsampling factor
+ r2: height upsampling factor
+ r3: width upsampling factor
+ """
+ b, packed_c, f, h, w = tensor.shape
+ factor = r1 * r2 * r3
+ c = packed_c // factor
+
+ tensor = tensor.view(b, r1, r2, r3, c, f, h, w)
+ tensor = tensor.permute(0, 4, 5, 1, 6, 2, 7, 3)
+ return tensor.reshape(b, c, f * r1, h * r2, w * r3)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_upsample else 1
+ h = self.conv(x)
+ if self.add_temporal_upsample:
+ h = self._dcae_upsample_rearrange(h, r1=1, r2=2, r3=2)
+ h = h[:, : h.shape[1] // 2]
+
+ # shortcut computation
+ shortcut = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
+ shortcut = shortcut.repeat_interleave(repeats=self.repeats // 2, dim=1)
+
+ else:
+ h = self._dcae_upsample_rearrange(h, r1=r1, r2=2, r3=2)
+ shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
+ shortcut = self._dcae_upsample_rearrange(shortcut, r1=r1, r2=2, r3=2)
+ return h + shortcut
+
+
+class HunyuanImageRefinerDownsampleDCAE(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
+ super().__init__()
+ factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
+ assert out_channels % factor == 0
+ # self.conv = Conv3d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
+ self.conv = HunyuanImageRefinerCausalConv3d(in_channels, out_channels // factor, kernel_size=3)
+
+ self.add_temporal_downsample = add_temporal_downsample
+ self.group_size = factor * in_channels // out_channels
+
+ @staticmethod
+ def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
+ """
+ Convert (b, c, r1*f, r2*h, r3*w) -> (b, r1*r2*r3*c, f, h, w)
+
+ This packs spatial/temporal dimensions into channels (opposite of upsample)
+ """
+ b, c, packed_f, packed_h, packed_w = tensor.shape
+ f, h, w = packed_f // r1, packed_h // r2, packed_w // r3
+
+ tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
+ tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
+ return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
+
+ def forward(self, x: torch.Tensor):
+ r1 = 2 if self.add_temporal_downsample else 1
+ h = self.conv(x)
+ if self.add_temporal_downsample:
+ # h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
+ h = self._dcae_downsample_rearrange(h, r1=1, r2=2, r3=2)
+ h = torch.cat([h, h], dim=1)
+ # shortcut computation
+ # shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
+ shortcut = self._dcae_downsample_rearrange(x, r1=1, r2=2, r3=2)
+ B, C, T, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
+ else:
+ # h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
+ h = self._dcae_downsample_rearrange(h, r1=r1, r2=2, r3=2)
+ # shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
+ shortcut = self._dcae_downsample_rearrange(x, r1=r1, r2=2, r3=2)
+ B, C, T, H, W = shortcut.shape
+ shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
+
+ return h + shortcut
+
+
+class HunyuanImageRefinerResnetBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: Optional[int] = None,
+ non_linearity: str = "swish",
+ ) -> None:
+ super().__init__()
+ out_channels = out_channels or in_channels
+
+ self.nonlinearity = get_activation(non_linearity)
+
+ self.norm1 = HunyuanImageRefinerRMS_norm(in_channels, images=False)
+ self.conv1 = HunyuanImageRefinerCausalConv3d(in_channels, out_channels, kernel_size=3)
+
+ self.norm2 = HunyuanImageRefinerRMS_norm(out_channels, images=False)
+ self.conv2 = HunyuanImageRefinerCausalConv3d(out_channels, out_channels, kernel_size=3)
+
+ self.conv_shortcut = None
+ if in_channels != out_channels:
+ self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv1(hidden_states)
+
+ hidden_states = self.norm2(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ residual = self.conv_shortcut(residual)
+
+ return hidden_states + residual
+
+
+class HunyuanImageRefinerMidBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_layers: int = 1,
+ add_attention: bool = True,
+ ) -> None:
+ super().__init__()
+ self.add_attention = add_attention
+
+ # There is always at least one resnet
+ resnets = [
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(HunyuanImageRefinerAttnBlock(in_channels))
+ else:
+ attentions.append(None)
+
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.resnets[0](hidden_states)
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ downsample_out_channels: Optional[int] = None,
+ add_temporal_downsample: int = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_out_channels is not None:
+ self.downsamplers = nn.ModuleList(
+ [
+ HunyuanImageRefinerDownsampleDCAE(
+ out_channels,
+ out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ upsample_out_channels: Optional[int] = None,
+ add_temporal_upsample: bool = True,
+ ) -> None:
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ HunyuanImageRefinerResnetBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_out_channels is not None:
+ self.upsamplers = nn.ModuleList(
+ [
+ HunyuanImageRefinerUpsampleDCAE(
+ out_channels,
+ out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for resnet in self.resnets:
+ hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
+
+ else:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class HunyuanImageRefinerEncoder3D(nn.Module):
+ r"""
+ 3D vae encoder for HunyuanImageRefiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 64,
+ block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ temporal_compression_ratio: int = 4,
+ spatial_compression_ratio: int = 16,
+ downsample_match_channel: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.group_size = block_out_channels[-1] // self.out_channels
+
+ self.conv_in = HunyuanImageRefinerCausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
+ self.mid_block = None
+ self.down_blocks = nn.ModuleList([])
+
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ add_spatial_downsample = i < np.log2(spatial_compression_ratio)
+ output_channel = block_out_channels[i]
+ if not add_spatial_downsample:
+ down_block = HunyuanImageRefinerDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=None,
+ add_temporal_downsample=False,
+ )
+ input_channel = output_channel
+ else:
+ add_temporal_downsample = i >= np.log2(spatial_compression_ratio // temporal_compression_ratio)
+ downsample_out_channels = block_out_channels[i + 1] if downsample_match_channel else output_channel
+ down_block = HunyuanImageRefinerDownBlock3D(
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ downsample_out_channels=downsample_out_channels,
+ add_temporal_downsample=add_temporal_downsample,
+ )
+ input_channel = downsample_out_channels
+
+ self.down_blocks.append(down_block)
+
+ self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[-1])
+
+ self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for down_block in self.down_blocks:
+ hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
+
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+ else:
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ hidden_states = self.mid_block(hidden_states)
+
+ # short_cut = rearrange(hidden_states, "b (c r) f h w -> b c r f h w", r=self.group_size).mean(dim=2)
+ batch_size, _, frame, height, width = hidden_states.shape
+ short_cut = hidden_states.view(batch_size, -1, self.group_size, frame, height, width).mean(dim=2)
+
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+
+ hidden_states += short_cut
+
+ return hidden_states
+
+
+class HunyuanImageRefinerDecoder3D(nn.Module):
+ r"""
+ Causal decoder for 3D video-like data used for HunyuanImage-2.1 Refiner.
+ """
+
+ def __init__(
+ self,
+ in_channels: int = 32,
+ out_channels: int = 3,
+ block_out_channels: Tuple[int, ...] = (1024, 1024, 512, 256, 128),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ upsample_match_channel: bool = True,
+ ):
+ super().__init__()
+ self.layers_per_block = layers_per_block
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.repeat = block_out_channels[0] // self.in_channels
+
+ self.conv_in = HunyuanImageRefinerCausalConv3d(self.in_channels, block_out_channels[0], kernel_size=3)
+ self.up_blocks = nn.ModuleList([])
+
+ # mid
+ self.mid_block = HunyuanImageRefinerMidBlock(in_channels=block_out_channels[0])
+
+ # up
+ input_channel = block_out_channels[0]
+ for i in range(len(block_out_channels)):
+ output_channel = block_out_channels[i]
+
+ add_spatial_upsample = i < np.log2(spatial_compression_ratio)
+ add_temporal_upsample = i < np.log2(temporal_compression_ratio)
+ if add_spatial_upsample or add_temporal_upsample:
+ upsample_out_channels = block_out_channels[i + 1] if upsample_match_channel else output_channel
+ up_block = HunyuanImageRefinerUpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=upsample_out_channels,
+ add_temporal_upsample=add_temporal_upsample,
+ )
+ input_channel = upsample_out_channels
+ else:
+ up_block = HunyuanImageRefinerUpBlock3D(
+ num_layers=self.layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ upsample_out_channels=None,
+ add_temporal_upsample=False,
+ )
+ input_channel = output_channel
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.norm_out = HunyuanImageRefinerRMS_norm(block_out_channels[-1], images=False)
+ self.conv_act = nn.SiLU()
+ self.conv_out = HunyuanImageRefinerCausalConv3d(block_out_channels[-1], out_channels, kernel_size=3)
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states) + hidden_states.repeat_interleave(repeats=self.repeat, dim=1)
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
+ else:
+ hidden_states = self.mid_block(hidden_states)
+
+ for up_block in self.up_blocks:
+ hidden_states = up_block(hidden_states)
+
+ # post-process
+ hidden_states = self.norm_out(hidden_states)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
+ r"""
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
+ HunyuanImage-2.1 Refiner.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ latent_channels: int = 32,
+ block_out_channels: Tuple[int] = (128, 256, 512, 1024, 1024),
+ layers_per_block: int = 2,
+ spatial_compression_ratio: int = 16,
+ temporal_compression_ratio: int = 4,
+ downsample_match_channel: bool = True,
+ upsample_match_channel: bool = True,
+ scaling_factor: float = 1.03682,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = HunyuanImageRefinerEncoder3D(
+ in_channels=in_channels,
+ out_channels=latent_channels * 2,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ downsample_match_channel=downsample_match_channel,
+ )
+
+ self.decoder = HunyuanImageRefinerDecoder3D(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ block_out_channels=list(reversed(block_out_channels)),
+ layers_per_block=layers_per_block,
+ temporal_compression_ratio=temporal_compression_ratio,
+ spatial_compression_ratio=spatial_compression_ratio,
+ upsample_match_channel=upsample_match_channel,
+ )
+
+ self.spatial_compression_ratio = spatial_compression_ratio
+ self.temporal_compression_ratio = temporal_compression_ratio
+
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+ # to perform decoding of a single video latent at a time.
+ self.use_slicing = False
+
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+ # intermediate tiles together, the memory requirement can be lowered.
+ self.use_tiling = False
+
+ # The minimal tile height and width for spatial tiling to be used
+ self.tile_sample_min_height = 256
+ self.tile_sample_min_width = 256
+
+ # The minimal distance between two spatial tiles
+ self.tile_sample_stride_height = 192
+ self.tile_sample_stride_width = 192
+
+ self.tile_overlap_factor = 0.25
+
+ def enable_tiling(
+ self,
+ tile_sample_min_height: Optional[int] = None,
+ tile_sample_min_width: Optional[int] = None,
+ tile_sample_stride_height: Optional[float] = None,
+ tile_sample_stride_width: Optional[float] = None,
+ tile_overlap_factor: Optional[float] = None,
+ ) -> None:
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+
+ Args:
+ tile_sample_min_height (`int`, *optional*):
+ The minimum height required for a sample to be separated into tiles across the height dimension.
+ tile_sample_min_width (`int`, *optional*):
+ The minimum width required for a sample to be separated into tiles across the width dimension.
+ tile_sample_stride_height (`int`, *optional*):
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+ no tiling artifacts produced across the height dimension.
+ tile_sample_stride_width (`int`, *optional*):
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+ artifacts produced across the width dimension.
+ """
+ self.use_tiling = True
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+ self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
+
+ def disable_tiling(self) -> None:
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self) -> None:
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self) -> None:
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = x.shape
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
+ x = self.encoder(x)
+ return x
+
+ @apply_forward_hook
+ def encode(
+ self, x: torch.Tensor, return_dict: bool = True
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+ r"""
+ Encode a batch of images into latents.
+
+ Args:
+ x (`torch.Tensor`): Input batch of images.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+
+ Returns:
+ The latent representations of the encoded videos. If `return_dict` is True, a
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+ """
+ if self.use_slicing and x.shape[0] > 1:
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+ h = torch.cat(encoded_slices)
+ else:
+ h = self._encode(x)
+
+ posterior = DiagonalGaussianDistribution(h)
+
+ if not return_dict:
+ return (posterior,)
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
+ _, _, _, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
+ return self.tiled_decode(z)
+
+ dec = self.decoder(z)
+
+ return dec
+
+ @apply_forward_hook
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Decode a batch of images.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z)
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
+ for y in range(blend_extent):
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+ y / blend_extent
+ )
+ return b
+
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+ x / blend_extent
+ )
+ return b
+
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
+ for x in range(blend_extent):
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
+ x / blend_extent
+ )
+ return b
+
+ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Encode a batch of images using a tiled encoder.
+
+ Args:
+ x (`torch.Tensor`): Input batch of videos.
+
+ Returns:
+ `torch.Tensor`:
+ The latent representation of the encoded videos.
+ """
+ _, _, _, height, width = x.shape
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 256 * (1 - 0.25) = 192
+ blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 8 * 0.25 = 2
+ blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 8 * 0.25 = 2
+ row_limit_height = tile_latent_min_height - blend_height # 8 - 2 = 6
+ row_limit_width = tile_latent_min_width - blend_width # 8 - 2 = 6
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = x[
+ :,
+ :,
+ :,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(tile)
+ row.append(tile)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ moments = torch.cat(result_rows, dim=-2)
+
+ return moments
+
+ def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
+ r"""
+ Decode a batch of images using a tiled decoder.
+
+ Args:
+ z (`torch.Tensor`): Input batch of latent vectors.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.vae.DecoderOutput`] or `tuple`:
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+ returned.
+ """
+
+ _, _, _, height, width = z.shape
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ overlap_height = int(tile_latent_min_height * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ overlap_width = int(tile_latent_min_width * (1 - self.tile_overlap_factor)) # 8 * (1 - 0.25) = 6
+ blend_height = int(tile_latent_min_height * self.tile_overlap_factor) # 256 * 0.25 = 64
+ blend_width = int(tile_latent_min_width * self.tile_overlap_factor) # 256 * 0.25 = 64
+ row_limit_height = tile_latent_min_height - blend_height # 256 - 64 = 192
+ row_limit_width = tile_latent_min_width - blend_width # 256 - 64 = 192
+
+ rows = []
+ for i in range(0, height, overlap_height):
+ row = []
+ for j in range(0, width, overlap_width):
+ tile = z[
+ :,
+ :,
+ :,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+ result_rows.append(torch.cat(result_row, dim=-1))
+ dec = torch.cat(result_rows, dim=-2)
+
+ return dec
+
+ def forward(
+ self,
+ sample: torch.Tensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.Tensor]:
+ r"""
+ Args:
+ sample (`torch.Tensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z, return_dict=return_dict)
+ return dec
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
index 51c600a4e9..47f2081b7e 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
@@ -26,7 +26,7 @@ from ..embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class LTXVideoCausalConv3d(nn.Module):
@@ -1034,7 +1034,7 @@ class LTXVideoDecoder3d(nn.Module):
return hidden_states
-class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLLTXVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -1219,27 +1219,6 @@ class AutoencoderKLLTXVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = x.shape
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
index 43294a901f..97ca9d6692 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_magvit.py
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -663,7 +663,7 @@ class EasyAnimateDecoder(nn.Module):
return hidden_states
-class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
+class AutoencoderKLMagvit(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. This
model is used in [EasyAnimate](https://huggingface.co/papers/2405.18991).
@@ -805,27 +805,6 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin):
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@apply_forward_hook
def _encode(
self, x: torch.Tensor, return_dict: bool = True
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
index 404d2f6d86..3ded9a0a54 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
@@ -27,7 +27,7 @@ from ..attention_processor import Attention, MochiVaeAttnProcessor2_0
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .autoencoder_kl_cogvideox import CogVideoXCausalConv3d
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -657,7 +657,7 @@ class MochiDecoder3D(nn.Module):
return hidden_states, new_conv_cache
-class AutoencoderKLMochi(ModelMixin, ConfigMixin):
+class AutoencoderKLMochi(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
@@ -818,27 +818,6 @@ class AutoencoderKLMochi(ModelMixin, ConfigMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def _enable_framewise_encoding(self):
r"""
Enables the framewise VAE encoding implementation with past latent padding. By default, Diffusers uses the
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
index 87ac406592..14db6aeb61 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
@@ -23,7 +23,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
@@ -32,7 +31,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -664,7 +663,7 @@ class QwenImageDecoder3d(nn.Module):
return x
-class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLQwenImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -764,27 +763,6 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def clear_cache(self):
def _count_conv3d(model):
count = 0
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
index cf46e52564..ab76254d19 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -23,7 +23,7 @@ from ..attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..unets.unet_3d_blocks import MidBlockTemporalDecoder, UpBlockTemporalDecoder
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
class TemporalDecoder(nn.Module):
@@ -135,7 +135,7 @@ class TemporalDecoder(nn.Module):
return sample
-class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
+class AutoencoderKLTemporalDecoder(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index d84a0861e9..f8bdfeb755 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -17,7 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
@@ -26,7 +25,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DiagonalGaussianDistribution
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -454,14 +453,14 @@ class WanMidBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
# First residual block
- x = self.resnets[0](x, feat_cache, feat_idx)
+ x = self.resnets[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
# Process through attention and residual blocks
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
x = attn(x)
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x
@@ -495,9 +494,9 @@ class WanResidualDownBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for resnet in self.resnets:
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
if self.downsampler is not None:
- x = self.downsampler(x, feat_cache, feat_idx)
+ x = self.downsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
return x + self.avg_shortcut(x_copy)
@@ -599,12 +598,12 @@ class WanEncoder3d(nn.Module):
## downsamples
for layer in self.down_blocks:
if feat_cache is not None:
- x = layer(x, feat_cache, feat_idx)
+ x = layer(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = layer(x)
## middle
- x = self.mid_block(x, feat_cache, feat_idx)
+ x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## head
x = self.norm_out(x)
@@ -695,13 +694,13 @@ class WanResidualUpBlock(nn.Module):
for resnet in self.resnets:
if feat_cache is not None:
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsampler is not None:
if feat_cache is not None:
- x = self.upsampler(x, feat_cache, feat_idx)
+ x = self.upsampler(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsampler(x)
@@ -768,13 +767,13 @@ class WanUpBlock(nn.Module):
"""
for resnet in self.resnets:
if feat_cache is not None:
- x = resnet(x, feat_cache, feat_idx)
+ x = resnet(x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = resnet(x)
if self.upsamplers is not None:
if feat_cache is not None:
- x = self.upsamplers[0](x, feat_cache, feat_idx)
+ x = self.upsamplers[0](x, feat_cache=feat_cache, feat_idx=feat_idx)
else:
x = self.upsamplers[0](x)
return x
@@ -886,11 +885,11 @@ class WanDecoder3d(nn.Module):
x = self.conv_in(x)
## middle
- x = self.mid_block(x, feat_cache, feat_idx)
+ x = self.mid_block(x, feat_cache=feat_cache, feat_idx=feat_idx)
## upsamples
for up_block in self.up_blocks:
- x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
+ x = up_block(x, feat_cache=feat_cache, feat_idx=feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -952,7 +951,7 @@ def unpatchify(x, patch_size):
return x
-class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
+class AutoencoderKLWan(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
Introduced in [Wan 2.1].
@@ -962,6 +961,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
_supports_gradient_checkpointing = False
+ # keys toignore when AlignDeviceHook moves inputs/outputs between devices
+ # these are shared mutable state modified in-place
+ _skip_keys = ["feat_cache", "feat_idx"]
@register_to_config
def __init__(
@@ -1052,7 +1054,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
is_residual=is_residual,
)
- self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+ self.spatial_compression_ratio = scale_factor_spatial
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
# to perform decoding of a single video latent at a time.
@@ -1111,27 +1113,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_tiling = False
-
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
def clear_cache(self):
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
self._conv_num = self._cached_conv_counts["decoder"]
@@ -1145,12 +1126,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def _encode(self, x: torch.Tensor):
_, _, num_frame, height, width = x.shape
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
- return self.tiled_encode(x)
-
self.clear_cache()
if self.config.patch_size is not None:
x = patchify(x, patch_size=self.config.patch_size)
+
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+ return self.tiled_encode(x)
+
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
@@ -1355,9 +1337,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
-
- blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
- blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+ tile_sample_stride_height = self.tile_sample_stride_height
+ tile_sample_stride_width = self.tile_sample_stride_width
+ if self.config.patch_size is not None:
+ sample_height = sample_height // self.config.patch_size
+ sample_width = sample_width // self.config.patch_size
+ tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+ tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+ blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+ else:
+ blend_height = self.tile_sample_min_height - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
@@ -1371,7 +1362,9 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self._conv_idx = [0]
tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
tile = self.post_quant_conv(tile)
- decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ decoded = self.decoder(
+ tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0)
+ )
time.append(decoded)
row.append(torch.cat(time, dim=2))
rows.append(row)
@@ -1387,11 +1380,15 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
tile = self.blend_v(rows[i - 1][j], tile, blend_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_width)
- result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+ result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width])
result_rows.append(torch.cat(result_row, dim=-1))
-
dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+ if self.config.patch_size is not None:
+ dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+ dec = torch.clamp(dec, min=-1.0, max=1.0)
+
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
index a10b616b4e..d832645592 100644
--- a/src/diffusers/models/autoencoders/autoencoder_oobleck.py
+++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py
@@ -25,6 +25,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.torch_utils import randn_tensor
from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
class Snake1d(nn.Module):
@@ -291,7 +292,7 @@ class OobleckDecoder(nn.Module):
return hidden_state
-class AutoencoderOobleck(ModelMixin, ConfigMixin):
+class AutoencoderOobleck(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First
introduced in Stable Audio.
@@ -356,20 +357,6 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
self.use_slicing = False
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
diff --git a/src/diffusers/models/autoencoders/autoencoder_tiny.py b/src/diffusers/models/autoencoders/autoencoder_tiny.py
index 3e2b28606e..b9ac713d73 100644
--- a/src/diffusers/models/autoencoders/autoencoder_tiny.py
+++ b/src/diffusers/models/autoencoders/autoencoder_tiny.py
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
-from .vae import DecoderOutput, DecoderTiny, EncoderTiny
+from .vae import AutoencoderMixin, DecoderOutput, DecoderTiny, EncoderTiny
@dataclass
@@ -38,7 +38,7 @@ class AutoencoderTinyOutput(BaseOutput):
latents: torch.Tensor
-class AutoencoderTiny(ModelMixin, ConfigMixin):
+class AutoencoderTiny(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.
@@ -162,35 +162,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
"""[0, 1] -> raw latents"""
return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)
- def enable_slicing(self) -> None:
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- def disable_slicing(self) -> None:
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
- def enable_tiling(self, use_tiling: bool = True) -> None:
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- def disable_tiling(self) -> None:
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
diff --git a/src/diffusers/models/autoencoders/consistency_decoder_vae.py b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
index b3017a8780..0a6258fed3 100644
--- a/src/diffusers/models/autoencoders/consistency_decoder_vae.py
+++ b/src/diffusers/models/autoencoders/consistency_decoder_vae.py
@@ -32,7 +32,7 @@ from ..attention_processor import (
)
from ..modeling_utils import ModelMixin
from ..unets.unet_2d import UNet2DModel
-from .vae import DecoderOutput, DiagonalGaussianDistribution, Encoder
+from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
@@ -49,7 +49,7 @@ class ConsistencyDecoderVAEOutput(BaseOutput):
latent_dist: "DiagonalGaussianDistribution"
-class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
+class ConsistencyDecoderVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
The consistency decoder used with DALL-E 3.
@@ -167,39 +167,6 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
- def enable_tiling(self, use_tiling: bool = True):
- r"""
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
- processing larger images.
- """
- self.use_tiling = use_tiling
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_tiling
- def disable_tiling(self):
- r"""
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.enable_tiling(False)
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_slicing
- def enable_slicing(self):
- r"""
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
- """
- self.use_slicing = True
-
- # Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.disable_slicing
- def disable_slicing(self):
- r"""
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
- decoding in one step.
- """
- self.use_slicing = False
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py
index 1d74d4f472..9c6031a988 100644
--- a/src/diffusers/models/autoencoders/vae.py
+++ b/src/diffusers/models/autoencoders/vae.py
@@ -286,11 +286,9 @@ class Decoder(nn.Module):
sample = self.conv_in(sample)
- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
- sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -298,7 +296,6 @@ class Decoder(nn.Module):
else:
# middle
sample = self.mid_block(sample, latent_embeds)
- sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
@@ -894,3 +891,38 @@ class DecoderTiny(nn.Module):
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
+
+
+class AutoencoderMixin:
+ def enable_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ if not hasattr(self, "use_tiling"):
+ raise NotImplementedError(f"Tiling doesn't seem to be implemented for {self.__class__.__name__}.")
+ self.use_tiling = True
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_tiling = False
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ if not hasattr(self, "use_slicing"):
+ raise NotImplementedError(f"Slicing doesn't seem to be implemented for {self.__class__.__name__}.")
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
diff --git a/src/diffusers/models/autoencoders/vq_model.py b/src/diffusers/models/autoencoders/vq_model.py
index c1094e62f7..82436473df 100644
--- a/src/diffusers/models/autoencoders/vq_model.py
+++ b/src/diffusers/models/autoencoders/vq_model.py
@@ -22,6 +22,7 @@ from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin
+from .vae import AutoencoderMixin
@dataclass
@@ -37,7 +38,7 @@ class VQEncoderOutput(BaseOutput):
latents: torch.Tensor
-class VQModel(ModelMixin, ConfigMixin):
+class VQModel(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VQ-VAE model for decoding latent representations.
diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 8d892cb3b6..0641c8bc01 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -270,11 +270,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -294,11 +290,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/controlnets/controlnet_xs.py b/src/diffusers/models/controlnets/controlnet_xs.py
index aabae709e9..f5c69b9a46 100644
--- a/src/diffusers/models/controlnets/controlnet_xs.py
+++ b/src/diffusers/models/controlnets/controlnet_xs.py
@@ -16,7 +16,6 @@ from math import gcd
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import Tensor, nn
from ...configuration_utils import ConfigMixin, register_to_config
@@ -980,11 +979,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1004,11 +999,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index b51f5d7aec..37fc412adc 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -319,13 +319,17 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False, dtype=None):
"""
This function generates 1D positional embeddings from a grid.
Args:
embed_dim (`int`): The embedding dimension `D`
pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+ output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
+ dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
+ `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
Returns:
`torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -341,7 +345,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
- omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+ # Auto-detect appropriate dtype if not specified
+ if dtype is None:
+ dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+
+ omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 573828dc4b..3f06099319 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -26,11 +26,11 @@ from flax.traverse_util import flatten_dict, unflatten_dict
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
validate_hf_hub_args,
)
-from requests import HTTPError
from .. import __version__, is_torch_available
from ..utils import (
@@ -113,14 +113,14 @@ class FlaxModelMixin(PushToHubMixin):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model parameters will be in fp32 precision, to cast these to bfloat16 precision
>>> params = model.to_bf16(params)
>>> # If you don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
@@ -149,7 +149,7 @@ class FlaxModelMixin(PushToHubMixin):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to illustrate the use of this method,
>>> # we'll first cast to fp16 and back to fp32
>>> params = model.to_f16(params)
@@ -179,14 +179,14 @@ class FlaxModelMixin(PushToHubMixin):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # load model
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # By default, the model params will be in fp32, to cast these to float16
>>> params = model.to_fp16(params)
>>> # If you want don't want to cast certain parameters (for example layer norm bias and scale)
>>> # then pass the mask as follows
>>> from flax import traverse_util
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> flat_params = traverse_util.flatten_dict(params)
>>> mask = {
... path: (path[-2] != ("LayerNorm", "bias") and path[-2:] != ("LayerNorm", "scale"))
@@ -216,8 +216,8 @@ class FlaxModelMixin(PushToHubMixin):
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- - A string, the *model id* (for example `runwayml/stable-diffusion-v1-5`) of a pretrained model
- hosted on the Hub.
+ - A string, the *model id* (for example `stable-diffusion-v1-5/stable-diffusion-v1-5`) of a
+ pretrained model hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
using [`~FlaxModelMixin.save_pretrained`].
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
@@ -227,15 +227,9 @@ class FlaxModelMixin(PushToHubMixin):
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
specified, all the computation will be performed with the given `dtype`.
-
-
- This only specifies the dtype of the *computation* and does not influence the dtype of model
- parameters.
-
- If you wish to change the dtype of the model parameters, see [`~FlaxModelMixin.to_fp16`] and
- [`~FlaxModelMixin.to_bf16`].
-
-
+ > [!TIP] > This only specifies the dtype of the *computation* and does not influence the dtype of model
+ > parameters. > > If you wish to change the dtype of the model parameters, see
+ [`~FlaxModelMixin.to_fp16`] and > [`~FlaxModelMixin.to_bf16`].
model_args (sequence of positional arguments, *optional*):
All remaining positional arguments are passed to the underlying model's `__init__` method.
@@ -277,7 +271,7 @@ class FlaxModelMixin(PushToHubMixin):
>>> from diffusers import FlaxUNet2DConditionModel
>>> # Download model and configuration from huggingface.co and cache.
- >>> model, params = FlaxUNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5")
+ >>> model, params = FlaxUNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
>>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
>>> model, params = FlaxUNet2DConditionModel.from_pretrained("./test/saved_model/")
```
@@ -285,7 +279,7 @@ class FlaxModelMixin(PushToHubMixin):
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -385,7 +379,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {FLAX_WEIGHTS_NAME}."
)
- except HTTPError as err:
+ except HfHubHTTPError as err:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n"
f"{err}"
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 2388989be2..91daca1ad8 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -65,6 +65,7 @@ from ..utils.hub_utils import (
populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
+from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
_caching_allocator_warmup,
_determine_device_map,
@@ -248,6 +249,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
_repeated_blocks = []
+ _parallel_config = None
+ _cp_plan = None
+ _skip_keys = None
def __init__(self):
super().__init__()
@@ -400,12 +404,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
inference. Speed up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes > precedent.
Parameters:
attention_op (`Callable`, *optional*):
@@ -620,8 +620,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
def reset_attention_backend(self) -> None:
"""
- Resets the attention backend for the model. Following calls to `forward` will use the environment default or
- the torch native scaled dot product attention.
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default, if
+ set, or the torch native scaled dot product attention.
"""
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -914,27 +914,23 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
Example:
```py
from diffusers import UNet2DConditionModel
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+ unet = UNet2DConditionModel.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
```
If you get the error message below, you need to finetune the weights for your downstream task:
```bash
- Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+ Some weights of UNet2DConditionModel were not initialized from the model checkpoint at stable-diffusion-v1-5/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
```
@@ -960,6 +956,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
+ parallel_config: Optional[Union[ParallelConfig, ContextParallelConfig]] = kwargs.pop("parallel_config", None)
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1340,6 +1337,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Set model in evaluation mode to deactivate DropOut modules by default
model.eval()
+ if parallel_config is not None:
+ model.enable_parallelism(config=parallel_config)
+
if output_loading_info:
return model, loading_info
@@ -1478,6 +1478,73 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
)
+ def enable_parallelism(
+ self,
+ *,
+ config: Union[ParallelConfig, ContextParallelConfig],
+ cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
+ ):
+ from ..hooks.context_parallel import apply_context_parallel
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning(
+ "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
+ )
+
+ if isinstance(config, ContextParallelConfig):
+ config = ParallelConfig(context_parallel_config=config)
+
+ if not torch.distributed.is_initialized():
+ raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
+
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+ device_type = torch._C._get_accelerator().type
+ device_module = torch.get_device_module(device_type)
+ device = torch.device(device_type, rank % device_module.device_count())
+
+ cp_mesh = None
+ if config.context_parallel_config is not None:
+ cp_config = config.context_parallel_config
+ if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
+ raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+ if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
+ raise ValueError(
+ "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+ )
+ if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
+ raise ValueError(
+ f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
+ )
+ cp_mesh = torch.distributed.device_mesh.init_device_mesh(
+ device_type=device_type,
+ mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
+ mesh_dim_names=("ring", "ulysses"),
+ )
+
+ config.setup(rank, world_size, device, cp_mesh=cp_mesh)
+
+ if cp_plan is None and self._cp_plan is None:
+ raise ValueError(
+ "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+ )
+ cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+
+ if config.context_parallel_config is not None:
+ apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
+ self._parallel_config = config
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_parallel_config"):
+ continue
+ processor._parallel_config = config
+
@classmethod
def _load_pretrained_model(
cls,
@@ -1734,7 +1801,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
```py
from diffusers import UNet2DConditionModel
- model_id = "runwayml/stable-diffusion-v1-5"
+ model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
unet.num_parameters(only_trainable=True)
859520964
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index b60f0636e6..15408a4b15 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -18,6 +18,7 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_allegro import AllegroTransformer3DModel
from .transformer_bria import BriaTransformer2DModel
+ from .transformer_bria_fibo import BriaFiboTransformer2DModel
from .transformer_chroma import ChromaTransformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_cogview4 import CogView4Transformer2DModel
@@ -27,11 +28,15 @@ if is_torch_available():
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel
+ from .transformer_hunyuanimage import HunyuanImageTransformer2DModel
+ from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
+ from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
+ from .transformer_sana_video import SanaVideoTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/auraflow_transformer_2d.py b/src/diffusers/models/transformers/auraflow_transformer_2d.py
index a8d275d142..bf6d9e1b38 100644
--- a/src/diffusers/models/transformers/auraflow_transformer_2d.py
+++ b/src/diffusers/models/transformers/auraflow_transformer_2d.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -92,7 +92,7 @@ class AuraFlowPatchEmbed(nn.Module):
return selected_indices
- def forward(self, latent):
+ def forward(self, latent) -> torch.Tensor:
batch_size, num_channels, height, width = latent.size()
latent = latent.view(
batch_size,
@@ -173,7 +173,7 @@ class AuraFlowSingleTransformerBlock(nn.Module):
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> torch.Tensor:
residual = hidden_states
attention_kwargs = attention_kwargs or {}
@@ -242,7 +242,7 @@ class AuraFlowJointTransformerBlock(nn.Module):
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
residual_context = encoder_hidden_states
attention_kwargs = attention_kwargs or {}
@@ -431,11 +431,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -455,11 +451,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -472,7 +464,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
timestep: torch.LongTensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/cogvideox_transformer_3d.py b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
index a8c98bccb8..9e0afdee66 100644
--- a/src/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -122,7 +122,7 @@ class CogVideoXBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
attention_kwargs = attention_kwargs or {}
@@ -397,11 +397,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -421,11 +417,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
@@ -441,7 +433,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cac
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/consisid_transformer_3d.py b/src/diffusers/models/transformers/consisid_transformer_3d.py
index 41632dbd47..91fe811f00 100644
--- a/src/diffusers/models/transformers/consisid_transformer_3d.py
+++ b/src/diffusers/models/transformers/consisid_transformer_3d.py
@@ -315,7 +315,7 @@ class ConsisIDBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -691,7 +691,7 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
id_cond: Optional[torch.Tensor] = None,
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/hunyuan_transformer_2d.py b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
index f634718788..fbe9fe8df9 100644
--- a/src/diffusers/models/transformers/hunyuan_transformer_2d.py
+++ b/src/diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -324,11 +324,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -348,11 +344,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index 84b1175386..bed5e69c2d 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -124,7 +124,7 @@ class LuminaNextDiTBlock(nn.Module):
encoder_mask: torch.Tensor,
temb: torch.Tensor,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- ):
+ ) -> torch.Tensor:
"""
Perform a forward pass through the LuminaNextDiTBlock.
@@ -297,7 +297,7 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
image_rotary_emb: torch.Tensor,
cross_attention_kwargs: Dict[str, Any] = None,
return_dict=True,
- ) -> torch.Tensor:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
Forward pass of LuminaNextDiT.
diff --git a/src/diffusers/models/transformers/pixart_transformer_2d.py b/src/diffusers/models/transformers/pixart_transformer_2d.py
index 40a14bfd9b..5a22144228 100644
--- a/src/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/src/diffusers/models/transformers/pixart_transformer_2d.py
@@ -258,11 +258,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -282,11 +278,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py
index 969e6db122..ac9b3fca41 100644
--- a/src/diffusers/models/transformers/stable_audio_transformer.py
+++ b/src/diffusers/models/transformers/stable_audio_transformer.py
@@ -18,7 +18,6 @@ from typing import Dict, Optional, Union
import numpy as np
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
diff --git a/src/diffusers/models/transformers/transformer_bria.py b/src/diffusers/models/transformers/transformer_bria.py
index 27a9941501..d54679306e 100644
--- a/src/diffusers/models/transformers/transformer_bria.py
+++ b/src/diffusers/models/transformers/transformer_bria.py
@@ -120,6 +120,7 @@ def get_1d_rotary_pos_embed(
class BriaAttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -161,7 +162,12 @@ class BriaAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
- query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -472,7 +478,7 @@ class BriaSingleTransformerBlock(nn.Module):
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -588,7 +594,7 @@ class BriaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
return_dict: bool = True,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
- ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`BriaTransformer2DModel`] forward method.
diff --git a/src/diffusers/models/transformers/transformer_bria_fibo.py b/src/diffusers/models/transformers/transformer_bria_fibo.py
new file mode 100644
index 0000000000..09f7961932
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_bria_fibo.py
@@ -0,0 +1,655 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...models.attention_processor import Attention
+from ...models.embeddings import TimestepEmbedding, apply_rotary_emb, get_1d_rotary_pos_embed, get_timestep_embedding
+from ...models.modeling_outputs import Transformer2DModelOutput
+from ...models.modeling_utils import ModelMixin
+from ...models.transformers.transformer_bria import BriaAttnProcessor
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _get_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "BriaFiboAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+# Copied from diffusers.models.transformers.transformer_flux.FluxAttnProcessor with FluxAttnProcessor->BriaFiboAttnProcessor, FluxAttention->BriaFiboAttention
+class BriaFiboAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "BriaFiboAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+# Based on https://github.com/huggingface/diffusers/blob/55d49d4379007740af20629bb61aba9546c6b053/src/diffusers/models/transformers/transformer_flux.py
+class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = BriaFiboAttnProcessor
+ _available_processors = [BriaFiboAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
+class BriaFiboEmbedND(torch.nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ freqs_dtype = torch.float32 if is_mps else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class BriaFiboSingleTransformerBlock(nn.Module):
+ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
+ super().__init__()
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+ self.norm = AdaLayerNormZeroSingle(dim)
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
+
+ processor = BriaAttnProcessor()
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ processor=processor,
+ qk_norm="rms_norm",
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.Tensor:
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ return hidden_states
+
+
+class BriaFiboTextProjection(nn.Module):
+ def __init__(self, in_features, hidden_size):
+ super().__init__()
+ self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
+
+ def forward(self, caption):
+ hidden_states = self.linear(caption)
+ return hidden_states
+
+
+@maybe_allow_in_graph
+# Based on from diffusers.models.transformers.transformer_flux.FluxTransformerBlock
+class BriaFiboTransformerBlock(nn.Module):
+ def __init__(
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+ ):
+ super().__init__()
+
+ self.norm1 = AdaLayerNormZero(dim)
+ self.norm1_context = AdaLayerNormZero(dim)
+
+ self.attn = BriaFiboAttention(
+ query_dim=dim,
+ added_kv_proj_dim=dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=False,
+ bias=True,
+ processor=BriaFiboAttnProcessor(),
+ eps=eps,
+ )
+
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+ joint_attention_kwargs = joint_attention_kwargs or {}
+
+ # Attention.
+ attention_outputs = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+
+ if len(attention_outputs) == 2:
+ attn_output, context_attn_output = attention_outputs
+ elif len(attention_outputs) == 3:
+ attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+ # Process attention outputs for the `hidden_states`.
+ attn_output = gate_msa.unsqueeze(1) * attn_output
+ hidden_states = hidden_states + attn_output
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+ hidden_states = hidden_states + ff_output
+ if len(attention_outputs) == 3:
+ hidden_states = hidden_states + ip_attn_output
+
+ # Process attention outputs for the `encoder_hidden_states`.
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+ if encoder_hidden_states.dtype == torch.float16:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return encoder_hidden_states, hidden_states
+
+
+class BriaFiboTimesteps(nn.Module):
+ def __init__(
+ self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
+ ):
+ super().__init__()
+ self.num_channels = num_channels
+ self.flip_sin_to_cos = flip_sin_to_cos
+ self.downscale_freq_shift = downscale_freq_shift
+ self.scale = scale
+ self.time_theta = time_theta
+
+ def forward(self, timesteps):
+ t_emb = get_timestep_embedding(
+ timesteps,
+ self.num_channels,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ downscale_freq_shift=self.downscale_freq_shift,
+ scale=self.scale,
+ max_period=self.time_theta,
+ )
+ return t_emb
+
+
+class BriaFiboTimestepProjEmbeddings(nn.Module):
+ def __init__(self, embedding_dim, time_theta):
+ super().__init__()
+
+ self.time_proj = BriaFiboTimesteps(
+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
+ )
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(self, timestep, dtype):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=dtype)) # (N, D)
+ return timesteps_emb
+
+
+class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ """
+ Parameters:
+ patch_size (`int`): Patch size to turn the input data into small patches.
+ in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
+ num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
+ num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
+ attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
+ num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
+ joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
+ guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+ ...
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: int = 1,
+ in_channels: int = 64,
+ num_layers: int = 19,
+ num_single_layers: int = 38,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 24,
+ joint_attention_dim: int = 4096,
+ pooled_projection_dim: int = None,
+ guidance_embeds: bool = False,
+ axes_dims_rope: List[int] = [16, 56, 56],
+ rope_theta=10000,
+ time_theta=10000,
+ text_encoder_dim: int = 2048,
+ ):
+ super().__init__()
+ self.out_channels = in_channels
+ self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+
+ self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
+
+ self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
+
+ if guidance_embeds:
+ self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)
+
+ self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
+ self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BriaFiboTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_layers)
+ ]
+ )
+
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ BriaFiboSingleTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.config.num_attention_heads,
+ attention_head_dim=self.config.attention_head_dim,
+ )
+ for i in range(self.config.num_single_layers)
+ ]
+ )
+
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+ self.gradient_checkpointing = False
+
+ caption_projection = [
+ BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
+ for i in range(self.config.num_layers + self.config.num_single_layers)
+ ]
+ self.caption_projection = nn.ModuleList(caption_projection)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ text_encoder_layers: list = None,
+ pooled_projections: torch.Tensor = None,
+ timestep: torch.LongTensor = None,
+ img_ids: torch.Tensor = None,
+ txt_ids: torch.Tensor = None,
+ guidance: torch.Tensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+ hidden_states = self.x_embedder(hidden_states)
+
+ timestep = timestep.to(hidden_states.dtype)
+ if guidance is not None:
+ guidance = guidance.to(hidden_states.dtype)
+ else:
+ guidance = None
+
+ temb = self.time_embed(timestep, dtype=hidden_states.dtype)
+
+ if guidance:
+ temb += self.guidance_embed(guidance, dtype=hidden_states.dtype)
+
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ if len(txt_ids.shape) == 3:
+ txt_ids = txt_ids[0]
+
+ if len(img_ids.shape) == 3:
+ img_ids = img_ids[0]
+
+ ids = torch.cat((txt_ids, img_ids), dim=0)
+ image_rotary_emb = self.pos_embed(ids)
+
+ new_text_encoder_layers = []
+ for i, text_encoder_layer in enumerate(text_encoder_layers):
+ text_encoder_layer = self.caption_projection[i](text_encoder_layer)
+ new_text_encoder_layers.append(text_encoder_layer)
+ text_encoder_layers = new_text_encoder_layers
+
+ block_id = 0
+ for index_block, block in enumerate(self.transformer_blocks):
+ current_text_encoder_layer = text_encoder_layers[block_id]
+ encoder_hidden_states = torch.cat(
+ [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
+ )
+ block_id += 1
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ for index_block, block in enumerate(self.single_transformer_blocks):
+ current_text_encoder_layer = text_encoder_layers[block_id]
+ encoder_hidden_states = torch.cat(
+ [encoder_hidden_states[:, :, : self.inner_dim // 2], current_text_encoder_layer], dim=-1
+ )
+ block_id += 1
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ temb,
+ image_rotary_emb,
+ joint_attention_kwargs,
+ )
+
+ else:
+ hidden_states = block(
+ hidden_states=hidden_states,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ joint_attention_kwargs=joint_attention_kwargs,
+ )
+
+ encoder_hidden_states = hidden_states[:, : encoder_hidden_states.shape[1], ...]
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+
+ hidden_states = self.norm_out(hidden_states, temb)
+ output = self.proj_out(hidden_states)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index 5823ae9d3d..2ef3643daf 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
"""
The Transformer model introduced in Flux, modified for Chroma.
- Reference: https://huggingface.co/lodestones/Chroma
+ Reference: https://huggingface.co/lodestones/Chroma1-HD
Args:
patch_size (`int`, defaults to `1`):
diff --git a/src/diffusers/models/transformers/transformer_cogview3plus.py b/src/diffusers/models/transformers/transformer_cogview3plus.py
index 77f15f6ca6..7356f4a606 100644
--- a/src/diffusers/models/transformers/transformer_cogview3plus.py
+++ b/src/diffusers/models/transformers/transformer_cogview3plus.py
@@ -13,7 +13,7 @@
# limitations under the License.
-from typing import Dict, Union
+from typing import Dict, Tuple, Union
import torch
import torch.nn as nn
@@ -79,7 +79,7 @@ class CogView3PlusTransformerBlock(nn.Module):
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
@@ -293,7 +293,7 @@ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index 25dcfa14cc..64e9a538a7 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -494,7 +494,7 @@ class CogView4TransformerBlock(nn.Module):
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
@@ -717,7 +717,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 7ab371a1a1..16c526f437 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -22,8 +22,9 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -73,6 +74,7 @@ def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_st
class FluxAttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -114,7 +116,12 @@ class FluxAttnProcessor:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
hidden_states = dispatch_attention_fn(
- query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -136,6 +143,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
_attention_backend = None
+ _parallel_config = None
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
@@ -220,6 +228,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -252,6 +261,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
@@ -556,6 +566,15 @@ class FluxTransformer2DModel(
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -698,7 +717,11 @@ class FluxTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
- image_rotary_emb = self.pos_embed(ids)
+ if is_torch_npu_available():
+ freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
+ image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
+ else:
+ image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
diff --git a/src/diffusers/models/transformers/transformer_hidream_image.py b/src/diffusers/models/transformers/transformer_hidream_image.py
index 77902dcf58..4a5aee29ab 100644
--- a/src/diffusers/models/transformers/transformer_hidream_image.py
+++ b/src/diffusers/models/transformers/transformer_hidream_image.py
@@ -55,7 +55,7 @@ class HiDreamImageTimestepEmbed(nn.Module):
self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
- def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None):
+ def forward(self, timesteps: torch.Tensor, wdtype: Optional[torch.dtype] = None) -> torch.Tensor:
t_emb = self.time_proj(timesteps).to(dtype=wdtype)
t_emb = self.timestep_embedder(t_emb)
return t_emb
@@ -87,7 +87,7 @@ class HiDreamImagePatchEmbed(nn.Module):
self.out_channels = out_channels
self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
- def forward(self, latent):
+ def forward(self, latent) -> torch.Tensor:
latent = self.proj(latent)
return latent
@@ -534,7 +534,7 @@ class HiDreamImageTransformerBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
wtype = hidden_states.dtype
(
shift_msa_i,
@@ -592,7 +592,7 @@ class HiDreamBlock(nn.Module):
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: torch.Tensor = None,
- ) -> torch.Tensor:
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
return self.block(
hidden_states=hidden_states,
hidden_states_masks=hidden_states_masks,
@@ -786,7 +786,7 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index 6944a6c536..bc857ccab4 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -529,7 +529,7 @@ class HunyuanVideoSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
*args,
**kwargs,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -684,7 +684,7 @@ class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
token_replace_emb: torch.Tensor = None,
num_tokens: int = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
@@ -1038,7 +1038,7 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
index c2eb7fd2a7..60b40fff3c 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video_framepack.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -216,7 +216,7 @@ class HunyuanVideoFramepackTransformer3DModel(
indices_latents_history_4x: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
- ):
+ ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/src/diffusers/models/transformers/transformer_hunyuanimage.py b/src/diffusers/models/transformers/transformer_hunyuanimage.py
new file mode 100644
index 0000000000..7f37bf815b
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_hunyuanimage.py
@@ -0,0 +1,971 @@
+# Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from diffusers.loaders import FromOriginalModelMixin
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import FeedForward
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention, AttentionProcessor
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ CombinedTimestepTextProjEmbeddings,
+ TimestepEmbedding,
+ Timesteps,
+ get_1d_rotary_pos_embed,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class HunyuanImageAttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "HunyuanImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ # 1. QKV projections
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ query = query.unflatten(2, (attn.heads, -1)) # batch_size, seq_len, heads, head_dim
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ # 2. QK normalization
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # 3. Rotational positional embeddings applied to latent stream
+ if image_rotary_emb is not None:
+ from ..embeddings import apply_rotary_emb
+
+ if attn.add_q_proj is None and encoder_hidden_states is not None:
+ query = torch.cat(
+ [
+ apply_rotary_emb(
+ query[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1
+ ),
+ query[:, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=1,
+ )
+ key = torch.cat(
+ [
+ apply_rotary_emb(key[:, : -encoder_hidden_states.shape[1]], image_rotary_emb, sequence_dim=1),
+ key[:, -encoder_hidden_states.shape[1] :],
+ ],
+ dim=1,
+ )
+ else:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # 4. Encoder condition QKV projection and normalization
+ if attn.add_q_proj is not None and encoder_hidden_states is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
+
+ if attn.norm_added_q is not None:
+ encoder_query = attn.norm_added_q(encoder_query)
+ if attn.norm_added_k is not None:
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([query, encoder_query], dim=1)
+ key = torch.cat([key, encoder_key], dim=1)
+ value = torch.cat([value, encoder_value], dim=1)
+
+ # 5. Attention
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ # 6. Output projection
+ if encoder_hidden_states is not None:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : -encoder_hidden_states.shape[1]],
+ hidden_states[:, -encoder_hidden_states.shape[1] :],
+ )
+
+ if getattr(attn, "to_out", None) is not None:
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if getattr(attn, "to_add_out", None) is not None:
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanImagePatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size: Union[Tuple[int, int], Tuple[int, int, int]] = (16, 16),
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ ) -> None:
+ super().__init__()
+
+ self.patch_size = patch_size
+
+ if len(patch_size) == 2:
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ elif len(patch_size) == 3:
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ else:
+ raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {len(patch_size)}")
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ return hidden_states
+
+
+class HunyuanImageByT5TextProjection(nn.Module):
+ def __init__(self, in_features: int, hidden_size: int, out_features: int):
+ super().__init__()
+ self.norm = nn.LayerNorm(in_features)
+ self.linear_1 = nn.Linear(in_features, hidden_size)
+ self.linear_2 = nn.Linear(hidden_size, hidden_size)
+ self.linear_3 = nn.Linear(hidden_size, out_features)
+ self.act_fn = nn.GELU()
+
+ def forward(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.norm(encoder_hidden_states)
+ hidden_states = self.linear_1(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_2(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = self.linear_3(hidden_states)
+ return hidden_states
+
+
+class HunyuanImageAdaNorm(nn.Module):
+ def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
+ super().__init__()
+
+ out_features = out_features or 2 * in_features
+ self.linear = nn.Linear(in_features, out_features)
+ self.nonlinearity = nn.SiLU()
+
+ def forward(
+ self, temb: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ temb = self.linear(self.nonlinearity(temb))
+ gate_msa, gate_mlp = temb.chunk(2, dim=1)
+ gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
+ return gate_msa, gate_mlp
+
+
+class HunyuanImageCombinedTimeGuidanceEmbedding(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ guidance_embeds: bool = False,
+ use_meanflow: bool = False,
+ ):
+ super().__init__()
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.use_meanflow = use_meanflow
+
+ self.time_proj_r = None
+ self.timestep_embedder_r = None
+ if use_meanflow:
+ self.time_proj_r = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder_r = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_embedder = None
+ if guidance_embeds:
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ timestep_r: Optional[torch.Tensor] = None,
+ guidance: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=timestep.dtype))
+
+ if timestep_r is not None:
+ timesteps_proj_r = self.time_proj_r(timestep_r)
+ timesteps_emb_r = self.timestep_embedder_r(timesteps_proj_r.to(dtype=timestep.dtype))
+ timesteps_emb = (timesteps_emb + timesteps_emb_r) / 2
+
+ if self.guidance_embedder is not None:
+ guidance_proj = self.time_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=timestep.dtype))
+ conditioning = timesteps_emb + guidance_emb
+ else:
+ conditioning = timesteps_emb
+
+ return conditioning
+
+
+# IndividualTokenRefinerBlock
+@maybe_allow_in_graph
+class HunyuanImageIndividualTokenRefinerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int, # 28
+ attention_head_dim: int, # 128
+ mlp_width_ratio: str = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ bias=attention_bias,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
+
+ self.norm_out = HunyuanImageAdaNorm(hidden_size, 2 * hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ norm_hidden_states = self.norm1(hidden_states)
+
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ )
+
+ gate_msa, gate_mlp = self.norm_out(temb)
+ hidden_states = hidden_states + attn_output * gate_msa
+
+ ff_output = self.ff(self.norm2(hidden_states))
+ hidden_states = hidden_states + ff_output * gate_mlp
+
+ return hidden_states
+
+
+class HunyuanImageIndividualTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_width_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.refiner_blocks = nn.ModuleList(
+ [
+ HunyuanImageIndividualTokenRefinerBlock(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_width_ratio=mlp_width_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> None:
+ self_attn_mask = None
+ if attention_mask is not None:
+ batch_size = attention_mask.shape[0]
+ seq_len = attention_mask.shape[1]
+ attention_mask = attention_mask.to(hidden_states.device)
+ self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
+ self_attn_mask[:, :, :, 0] = True
+
+ for block in self.refiner_blocks:
+ hidden_states = block(hidden_states, temb, self_attn_mask)
+
+ return hidden_states
+
+
+# txt_in
+class HunyuanImageTokenRefiner(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_layers: int,
+ mlp_ratio: float = 4.0,
+ mlp_drop_rate: float = 0.0,
+ attention_bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=hidden_size, pooled_projection_dim=in_channels
+ )
+ self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
+ self.token_refiner = HunyuanImageIndividualTokenRefiner(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_layers=num_layers,
+ mlp_width_ratio=mlp_ratio,
+ mlp_drop_rate=mlp_drop_rate,
+ attention_bias=attention_bias,
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ) -> torch.Tensor:
+ if attention_mask is None:
+ pooled_hidden_states = hidden_states.mean(dim=1)
+ else:
+ original_dtype = hidden_states.dtype
+ mask_float = attention_mask.float().unsqueeze(-1)
+ pooled_hidden_states = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
+ pooled_hidden_states = pooled_hidden_states.to(original_dtype)
+
+ temb = self.time_text_embed(timestep, pooled_hidden_states)
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
+
+ return hidden_states
+
+
+class HunyuanImageRotaryPosEmbed(nn.Module):
+ def __init__(
+ self, patch_size: Union[Tuple, List[int]], rope_dim: Union[Tuple, List[int]], theta: float = 256.0
+ ) -> None:
+ super().__init__()
+
+ if not isinstance(patch_size, (tuple, list)) or len(patch_size) not in [2, 3]:
+ raise ValueError(f"patch_size must be a tuple or list of length 2 or 3, got {patch_size}")
+
+ if not isinstance(rope_dim, (tuple, list)) or len(rope_dim) not in [2, 3]:
+ raise ValueError(f"rope_dim must be a tuple or list of length 2 or 3, got {rope_dim}")
+
+ if not len(patch_size) == len(rope_dim):
+ raise ValueError(f"patch_size and rope_dim must have the same length, got {patch_size} and {rope_dim}")
+
+ self.patch_size = patch_size
+ self.rope_dim = rope_dim
+ self.theta = theta
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if hidden_states.ndim == 5:
+ _, _, frame, height, width = hidden_states.shape
+ patch_size_frame, patch_size_height, patch_size_width = self.patch_size
+ rope_sizes = [frame // patch_size_frame, height // patch_size_height, width // patch_size_width]
+ elif hidden_states.ndim == 4:
+ _, _, height, width = hidden_states.shape
+ patch_size_height, patch_size_width = self.patch_size
+ rope_sizes = [height // patch_size_height, width // patch_size_width]
+ else:
+ raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
+
+ axes_grids = []
+ for i in range(len(rope_sizes)):
+ grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
+ axes_grids.append(grid)
+ grid = torch.meshgrid(*axes_grids, indexing="ij") # dim x [H, W]
+ grid = torch.stack(grid, dim=0) # [2, H, W]
+
+ freqs = []
+ for i in range(len(rope_sizes)):
+ freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
+ freqs.append(freq)
+
+ freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
+ freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
+ return freqs_cos, freqs_sin
+
+
+@maybe_allow_in_graph
+class HunyuanImageSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float = 4.0,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+ mlp_dim = int(hidden_size * mlp_ratio)
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ bias=True,
+ processor=HunyuanImageAttnProcessor(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ pre_only=True,
+ )
+
+ self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
+ self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
+ self.act_mlp = nn.GELU(approximate="tanh")
+ self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ text_seq_length = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+ residual = hidden_states
+
+ # 1. Input normalization
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ norm_hidden_states, norm_encoder_hidden_states = (
+ norm_hidden_states[:, :-text_seq_length, :],
+ norm_hidden_states[:, -text_seq_length:, :],
+ )
+
+ # 2. Attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ attn_output = torch.cat([attn_output, context_attn_output], dim=1)
+
+ # 3. Modulation and residual connection
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
+ hidden_states = hidden_states + residual
+
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, :-text_seq_length, :],
+ hidden_states[:, -text_seq_length:, :],
+ )
+ return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
+class HunyuanImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ mlp_ratio: float,
+ qk_norm: str = "rms_norm",
+ ) -> None:
+ super().__init__()
+
+ hidden_size = num_attention_heads * attention_head_dim
+
+ self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+ self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
+
+ self.attn = Attention(
+ query_dim=hidden_size,
+ cross_attention_dim=None,
+ added_kv_proj_dim=hidden_size,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=hidden_size,
+ context_pre_only=False,
+ bias=True,
+ processor=HunyuanImageAttnProcessor(),
+ qk_norm=qk_norm,
+ eps=1e-6,
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # 1. Input normalization
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+ encoder_hidden_states, emb=temb
+ )
+
+ # 2. Joint attention
+ attn_output, context_attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 3. Modulation and residual connection
+ hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
+ encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
+
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+ # 4. Feed-forward
+ ff_output = self.ff(norm_hidden_states)
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
+
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+ return hidden_states, encoder_hidden_states
+
+
+class HunyuanImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ The Transformer model used in [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1).
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `24`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of dual-stream blocks to use.
+ num_single_layers (`int`, defaults to `40`):
+ The number of layers of single-stream blocks to use.
+ num_refiner_layers (`int`, defaults to `2`):
+ The number of layers of refiner blocks to use.
+ mlp_ratio (`float`, defaults to `4.0`):
+ The ratio of the hidden layer size to the input size in the feedforward network.
+ patch_size (`int`, defaults to `2`):
+ The size of the spatial patches to use in the patch embedding layer.
+ patch_size_t (`int`, defaults to `1`):
+ The size of the tmeporal patches to use in the patch embedding layer.
+ qk_norm (`str`, defaults to `rms_norm`):
+ The normalization to use for the query and key projections in the attention layers.
+ guidance_embeds (`bool`, defaults to `True`):
+ Whether to use guidance embeddings in the model.
+ text_embed_dim (`int`, defaults to `4096`):
+ Input dimension of text embeddings from the text encoder.
+ pooled_projection_dim (`int`, defaults to `768`):
+ The dimension of the pooled projection of the text embeddings.
+ rope_theta (`float`, defaults to `256.0`):
+ The value of theta to use in the RoPE layer.
+ rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
+ The dimensions of the axes to use in the RoPE layer.
+ image_condition_type (`str`, *optional*, defaults to `None`):
+ The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
+ image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
+ tokens in the latent stream and apply conditioning.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
+ _no_split_modules = [
+ "HunyuanImageTransformerBlock",
+ "HunyuanImageSingleTransformerBlock",
+ "HunyuanImagePatchEmbed",
+ "HunyuanImageTokenRefiner",
+ ]
+ _repeated_blocks = [
+ "HunyuanImageTransformerBlock",
+ "HunyuanImageSingleTransformerBlock",
+ ]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 64,
+ out_channels: int = 64,
+ num_attention_heads: int = 28,
+ attention_head_dim: int = 128,
+ num_layers: int = 20,
+ num_single_layers: int = 40,
+ num_refiner_layers: int = 2,
+ mlp_ratio: float = 4.0,
+ patch_size: Tuple[int, int] = (1, 1),
+ qk_norm: str = "rms_norm",
+ guidance_embeds: bool = False,
+ text_embed_dim: int = 3584,
+ text_embed_2_dim: Optional[int] = None,
+ rope_theta: float = 256.0,
+ rope_axes_dim: Tuple[int] = (64, 64),
+ use_meanflow: bool = False,
+ ) -> None:
+ super().__init__()
+
+ if not (isinstance(patch_size, (tuple, list)) and len(patch_size) in [2, 3]):
+ raise ValueError(f"patch_size must be a tuple of length 2 or 3, got {patch_size}")
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Latent and condition embedders
+ self.x_embedder = HunyuanImagePatchEmbed(patch_size, in_channels, inner_dim)
+ self.context_embedder = HunyuanImageTokenRefiner(
+ text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
+ )
+
+ if text_embed_2_dim is not None:
+ self.context_embedder_2 = HunyuanImageByT5TextProjection(text_embed_2_dim, 2048, inner_dim)
+ else:
+ self.context_embedder_2 = None
+
+ self.time_guidance_embed = HunyuanImageCombinedTimeGuidanceEmbedding(inner_dim, guidance_embeds, use_meanflow)
+
+ # 2. RoPE
+ self.rope = HunyuanImageRotaryPosEmbed(patch_size, rope_axes_dim, rope_theta)
+
+ # 3. Dual stream transformer blocks
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ HunyuanImageTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Single stream transformer blocks
+ self.single_transformer_blocks = nn.ModuleList(
+ [
+ HunyuanImageSingleTransformerBlock(
+ num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
+ )
+ for _ in range(num_single_layers)
+ ]
+ )
+
+ # 5. Output projection
+ self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
+
+ self.gradient_checkpointing = False
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_attention_mask: torch.Tensor,
+ timestep_r: Optional[torch.LongTensor] = None,
+ encoder_hidden_states_2: Optional[torch.Tensor] = None,
+ encoder_attention_mask_2: Optional[torch.Tensor] = None,
+ guidance: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ if hidden_states.ndim == 4:
+ batch_size, channels, height, width = hidden_states.shape
+ sizes = (height, width)
+ elif hidden_states.ndim == 5:
+ batch_size, channels, frame, height, width = hidden_states.shape
+ sizes = (frame, height, width)
+ else:
+ raise ValueError(f"hidden_states must be a 4D or 5D tensor, got {hidden_states.shape}")
+
+ post_patch_sizes = tuple(d // p for d, p in zip(sizes, self.config.patch_size))
+
+ # 1. RoPE
+ image_rotary_emb = self.rope(hidden_states)
+
+ # 2. Conditional embeddings
+ encoder_attention_mask = encoder_attention_mask.bool()
+ temb = self.time_guidance_embed(timestep, guidance=guidance, timestep_r=timestep_r)
+ hidden_states = self.x_embedder(hidden_states)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
+
+ if self.context_embedder_2 is not None and encoder_hidden_states_2 is not None:
+ encoder_hidden_states_2 = self.context_embedder_2(encoder_hidden_states_2)
+
+ encoder_attention_mask_2 = encoder_attention_mask_2.bool()
+
+ # reorder and combine text tokens: combine valid tokens first, then padding
+ new_encoder_hidden_states = []
+ new_encoder_attention_mask = []
+
+ for text, text_mask, text_2, text_mask_2 in zip(
+ encoder_hidden_states, encoder_attention_mask, encoder_hidden_states_2, encoder_attention_mask_2
+ ):
+ # Concatenate: [valid_mllm, valid_byt5, invalid_mllm, invalid_byt5]
+ new_encoder_hidden_states.append(
+ torch.cat(
+ [
+ text_2[text_mask_2], # valid byt5
+ text[text_mask], # valid mllm
+ text_2[~text_mask_2], # invalid byt5
+ text[~text_mask], # invalid mllm
+ ],
+ dim=0,
+ )
+ )
+
+ # Apply same reordering to attention masks
+ new_encoder_attention_mask.append(
+ torch.cat(
+ [
+ text_mask_2[text_mask_2],
+ text_mask[text_mask],
+ text_mask_2[~text_mask_2],
+ text_mask[~text_mask],
+ ],
+ dim=0,
+ )
+ )
+
+ encoder_hidden_states = torch.stack(new_encoder_hidden_states)
+ encoder_attention_mask = torch.stack(new_encoder_attention_mask)
+
+ attention_mask = torch.nn.functional.pad(encoder_attention_mask, (hidden_states.shape[1], 0), value=True)
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ # 3. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ else:
+ for block in self.transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ for block in self.single_transformer_blocks:
+ hidden_states, encoder_hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # 4. Output projection
+ hidden_states = self.norm_out(hidden_states, temb)
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. unpatchify
+ # reshape: [batch_size, *post_patch_dims, channels, *patch_size]
+ out_channels = self.config.out_channels
+ reshape_dims = [batch_size] + list(post_patch_sizes) + [out_channels] + list(self.config.patch_size)
+ hidden_states = hidden_states.reshape(*reshape_dims)
+
+ # create permutation pattern: batch, channels, then interleave post_patch and patch dims
+ # For 4D: [0, 3, 1, 4, 2, 5] -> batch, channels, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+ # For 5D: [0, 4, 1, 5, 2, 6, 3, 7] -> batch, channels, post_patch_frame, patch_size_frame, post_patch_height, patch_size_height, post_patch_width, patch_size_width
+ ndim = len(post_patch_sizes)
+ permute_pattern = [0, ndim + 1] # batch, channels
+ for i in range(ndim):
+ permute_pattern.extend([i + 1, ndim + 2 + i]) # post_patch_sizes[i], patch_sizes[i]
+ hidden_states = hidden_states.permute(*permute_pattern)
+
+ # flatten patch dimensions: flatten each (post_patch_size, patch_size) pair
+ # batch_size, channels, post_patch_sizes[0] * patch_sizes[0], post_patch_sizes[1] * patch_sizes[1], ...
+ final_dims = [batch_size, out_channels] + [
+ post_patch * patch for post_patch, patch in zip(post_patch_sizes, self.config.patch_size)
+ ]
+ hidden_states = hidden_states.reshape(*final_dims)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (hidden_states,)
+
+ return Transformer2DModelOutput(sample=hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py
new file mode 100644
index 0000000000..316e79da4f
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_kandinsky.py
@@ -0,0 +1,669 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import (
+ logging,
+)
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn
+from ..cache_utils import CacheMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_freqs(dim, max_period=10000.0):
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
+ return freqs
+
+
+def fractal_flatten(x, rope, shape, block_mask=False):
+ if block_mask:
+ pixel_size = 8
+ x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+ rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
+ x = x.flatten(1, 2)
+ rope = rope.flatten(1, 2)
+ else:
+ x = x.flatten(1, 3)
+ rope = rope.flatten(1, 3)
+ return x, rope
+
+
+def fractal_unflatten(x, shape, block_mask=False):
+ if block_mask:
+ pixel_size = 8
+ x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:])
+ x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
+ else:
+ x = x.reshape(*shape, *x.shape[2:])
+ return x
+
+
+def local_patching(x, shape, group_size, dim=0):
+ batch_size, duration, height, width = shape
+ g1, g2, g3 = group_size
+ x = x.reshape(
+ *x.shape[:dim],
+ duration // g1,
+ g1,
+ height // g2,
+ g2,
+ width // g3,
+ g3,
+ *x.shape[dim + 3 :],
+ )
+ x = x.permute(
+ *range(len(x.shape[:dim])),
+ dim,
+ dim + 2,
+ dim + 4,
+ dim + 1,
+ dim + 3,
+ dim + 5,
+ *range(dim + 6, len(x.shape)),
+ )
+ x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
+ return x
+
+
+def local_merge(x, shape, group_size, dim=0):
+ batch_size, duration, height, width = shape
+ g1, g2, g3 = group_size
+ x = x.reshape(
+ *x.shape[:dim],
+ duration // g1,
+ height // g2,
+ width // g3,
+ g1,
+ g2,
+ g3,
+ *x.shape[dim + 2 :],
+ )
+ x = x.permute(
+ *range(len(x.shape[:dim])),
+ dim,
+ dim + 3,
+ dim + 1,
+ dim + 4,
+ dim + 2,
+ dim + 5,
+ *range(dim + 6, len(x.shape)),
+ )
+ x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
+ return x
+
+
+def nablaT_v2(
+ q: Tensor,
+ k: Tensor,
+ sta: Tensor,
+ thr: float = 0.9,
+):
+ if _CAN_USE_FLEX_ATTN:
+ from torch.nn.attention.flex_attention import BlockMask
+ else:
+ raise ValueError("Nabla attention is not supported with this version of PyTorch")
+
+ q = q.transpose(1, 2).contiguous()
+ k = k.transpose(1, 2).contiguous()
+
+ # Map estimation
+ B, h, S, D = q.shape
+ s1 = S // 64
+ qa = q.reshape(B, h, s1, 64, D).mean(-2)
+ ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+ map = qa @ ka
+
+ map = torch.softmax(map / math.sqrt(D), dim=-1)
+ # Map binarization
+ vals, inds = map.sort(-1)
+ cvals = vals.cumsum_(-1)
+ mask = (cvals >= 1 - thr).int()
+ mask = mask.gather(-1, inds.argsort(-1))
+
+ mask = torch.logical_or(mask, sta)
+
+ # BlockMask creation
+ kv_nb = mask.sum(-1).to(torch.int32)
+ kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+ return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None)
+
+
+class Kandinsky5TimeEmbeddings(nn.Module):
+ def __init__(self, model_dim, time_dim, max_period=10000.0):
+ super().__init__()
+ assert model_dim % 2 == 0
+ self.model_dim = model_dim
+ self.max_period = max_period
+ self.freqs = get_freqs(self.model_dim // 2, self.max_period)
+ self.in_layer = nn.Linear(model_dim, time_dim, bias=True)
+ self.activation = nn.SiLU()
+ self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
+
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
+ def forward(self, time):
+ args = torch.outer(time, self.freqs.to(device=time.device))
+ time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
+ return time_embed
+
+
+class Kandinsky5TextEmbeddings(nn.Module):
+ def __init__(self, text_dim, model_dim):
+ super().__init__()
+ self.in_layer = nn.Linear(text_dim, model_dim, bias=True)
+ self.norm = nn.LayerNorm(model_dim, elementwise_affine=True)
+
+ def forward(self, text_embed):
+ text_embed = self.in_layer(text_embed)
+ return self.norm(text_embed).type_as(text_embed)
+
+
+class Kandinsky5VisualEmbeddings(nn.Module):
+ def __init__(self, visual_dim, model_dim, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+ self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim)
+
+ def forward(self, x):
+ batch_size, duration, height, width, dim = x.shape
+ x = (
+ x.view(
+ batch_size,
+ duration // self.patch_size[0],
+ self.patch_size[0],
+ height // self.patch_size[1],
+ self.patch_size[1],
+ width // self.patch_size[2],
+ self.patch_size[2],
+ dim,
+ )
+ .permute(0, 1, 3, 5, 2, 4, 6, 7)
+ .flatten(4, 7)
+ )
+ return self.in_layer(x)
+
+
+class Kandinsky5RoPE1D(nn.Module):
+ def __init__(self, dim, max_pos=1024, max_period=10000.0):
+ super().__init__()
+ self.max_period = max_period
+ self.dim = dim
+ self.max_pos = max_pos
+ freq = get_freqs(dim // 2, max_period)
+ pos = torch.arange(max_pos, dtype=freq.dtype)
+ self.register_buffer("args", torch.outer(pos, freq), persistent=False)
+
+ def forward(self, pos):
+ args = self.args[pos]
+ cosine = torch.cos(args)
+ sine = torch.sin(args)
+ rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
+ rope = rope.view(*rope.shape[:-1], 2, 2)
+ return rope.unsqueeze(-4)
+
+
+class Kandinsky5RoPE3D(nn.Module):
+ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0):
+ super().__init__()
+ self.axes_dims = axes_dims
+ self.max_pos = max_pos
+ self.max_period = max_period
+
+ for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)):
+ freq = get_freqs(axes_dim // 2, max_period)
+ pos = torch.arange(ax_max_pos, dtype=freq.dtype)
+ self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False)
+
+ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)):
+ batch_size, duration, height, width = shape
+ args_t = self.args_0[pos[0]] / scale_factor[0]
+ args_h = self.args_1[pos[1]] / scale_factor[1]
+ args_w = self.args_2[pos[2]] / scale_factor[2]
+
+ args = torch.cat(
+ [
+ args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1),
+ args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1),
+ args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1),
+ ],
+ dim=-1,
+ )
+ cosine = torch.cos(args)
+ sine = torch.sin(args)
+ rope = torch.stack([cosine, -sine, sine, cosine], dim=-1)
+ rope = rope.view(*rope.shape[:-1], 2, 2)
+ return rope.unsqueeze(-4)
+
+
+class Kandinsky5Modulation(nn.Module):
+ def __init__(self, time_dim, model_dim, num_params):
+ super().__init__()
+ self.activation = nn.SiLU()
+ self.out_layer = nn.Linear(time_dim, num_params * model_dim)
+ self.out_layer.weight.data.zero_()
+ self.out_layer.bias.data.zero_()
+
+ @torch.autocast(device_type="cuda", dtype=torch.float32)
+ def forward(self, x):
+ return self.out_layer(self.activation(x))
+
+
+class Kandinsky5AttnProcessor:
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None):
+ # query, key, value = self.get_qkv(x)
+ query = attn.to_query(hidden_states)
+
+ if encoder_hidden_states is not None:
+ key = attn.to_key(encoder_hidden_states)
+ value = attn.to_value(encoder_hidden_states)
+
+ shape, cond_shape = query.shape[:-1], key.shape[:-1]
+ query = query.reshape(*shape, attn.num_heads, -1)
+ key = key.reshape(*cond_shape, attn.num_heads, -1)
+ value = value.reshape(*cond_shape, attn.num_heads, -1)
+
+ else:
+ key = attn.to_key(hidden_states)
+ value = attn.to_value(hidden_states)
+
+ shape = query.shape[:-1]
+ query = query.reshape(*shape, attn.num_heads, -1)
+ key = key.reshape(*shape, attn.num_heads, -1)
+ value = value.reshape(*shape, attn.num_heads, -1)
+
+ # query, key = self.norm_qk(query, key)
+ query = attn.query_norm(query.float()).type_as(query)
+ key = attn.key_norm(key.float()).type_as(key)
+
+ def apply_rotary(x, rope):
+ x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
+ x_out = (rope * x_).sum(dim=-1)
+ return x_out.reshape(*x.shape).to(torch.bfloat16)
+
+ if rotary_emb is not None:
+ query = apply_rotary(query, rotary_emb).type_as(query)
+ key = apply_rotary(key, rotary_emb).type_as(key)
+
+ if sparse_params is not None:
+ attn_mask = nablaT_v2(
+ query,
+ key,
+ sparse_params["sta_mask"],
+ thr=sparse_params["P"],
+ )
+
+ else:
+ attn_mask = None
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ hidden_states = hidden_states.flatten(-2, -1)
+
+ attn_out = attn.out_layer(hidden_states)
+ return attn_out
+
+
+class Kandinsky5Attention(nn.Module, AttentionModuleMixin):
+ _default_processor_cls = Kandinsky5AttnProcessor
+ _available_processors = [
+ Kandinsky5AttnProcessor,
+ ]
+
+ def __init__(self, num_channels, head_dim, processor=None):
+ super().__init__()
+ assert num_channels % head_dim == 0
+ self.num_heads = num_channels // head_dim
+
+ self.to_query = nn.Linear(num_channels, num_channels, bias=True)
+ self.to_key = nn.Linear(num_channels, num_channels, bias=True)
+ self.to_value = nn.Linear(num_channels, num_channels, bias=True)
+ self.query_norm = nn.RMSNorm(head_dim)
+ self.key_norm = nn.RMSNorm(head_dim)
+
+ self.out_layer = nn.Linear(num_channels, num_channels, bias=True)
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ sparse_params: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ sparse_params=sparse_params,
+ rotary_emb=rotary_emb,
+ **kwargs,
+ )
+
+
+class Kandinsky5FeedForward(nn.Module):
+ def __init__(self, dim, ff_dim):
+ super().__init__()
+ self.in_layer = nn.Linear(dim, ff_dim, bias=False)
+ self.activation = nn.GELU()
+ self.out_layer = nn.Linear(ff_dim, dim, bias=False)
+
+ def forward(self, x):
+ return self.out_layer(self.activation(self.in_layer(x)))
+
+
+class Kandinsky5OutLayer(nn.Module):
+ def __init__(self, model_dim, time_dim, visual_dim, patch_size):
+ super().__init__()
+ self.patch_size = patch_size
+ self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2)
+ self.norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True)
+
+ def forward(self, visual_embed, text_embed, time_embed):
+ shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+
+ visual_embed = (
+ self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]
+ ).type_as(visual_embed)
+
+ x = self.out_layer(visual_embed)
+
+ batch_size, duration, height, width, _ = x.shape
+ x = (
+ x.view(
+ batch_size,
+ duration,
+ height,
+ width,
+ -1,
+ self.patch_size[0],
+ self.patch_size[1],
+ self.patch_size[2],
+ )
+ .permute(0, 1, 5, 2, 6, 3, 7, 4)
+ .flatten(1, 2)
+ .flatten(2, 3)
+ .flatten(3, 4)
+ )
+ return x
+
+
+class Kandinsky5TransformerEncoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+ super().__init__()
+ self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6)
+
+ self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+ def forward(self, x, time_embed, rope):
+ self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1)
+ shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+ out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
+ out = self.self_attention(out, rotary_emb=rope)
+ x = (x.float() + gate.float() * out.float()).type_as(x)
+
+ shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+ out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x)
+ out = self.feed_forward(out)
+ x = (x.float() + gate.float() * out.float()).type_as(x)
+
+ return x
+
+
+class Kandinsky5TransformerDecoderBlock(nn.Module):
+ def __init__(self, model_dim, time_dim, ff_dim, head_dim):
+ super().__init__()
+ self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9)
+
+ self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor())
+
+ self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
+ self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim)
+
+ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params):
+ self_attn_params, cross_attn_params, ff_params = torch.chunk(
+ self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1
+ )
+
+ shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1)
+ visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1)
+ visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ shift, scale, gate = torch.chunk(ff_params, 3, dim=-1)
+ visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(
+ visual_embed
+ )
+ visual_out = self.feed_forward(visual_out)
+ visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed)
+
+ return visual_embed
+
+
+class Kandinsky5Transformer3DModel(
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ CacheMixin,
+ AttentionMixin,
+):
+ """
+ A 3D Diffusion Transformer model for video-like data.
+ """
+
+ _repeated_blocks = [
+ "Kandinsky5TransformerEncoderBlock",
+ "Kandinsky5TransformerDecoderBlock",
+ ]
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_visual_dim=4,
+ in_text_dim=3584,
+ in_text_dim2=768,
+ time_dim=512,
+ out_visual_dim=4,
+ patch_size=(1, 2, 2),
+ model_dim=2048,
+ ff_dim=5120,
+ num_text_blocks=2,
+ num_visual_blocks=32,
+ axes_dims=(16, 24, 24),
+ visual_cond=False,
+ attention_type: str = "regular",
+ attention_causal: bool = None,
+ attention_local: bool = None,
+ attention_glob: bool = None,
+ attention_window: int = None,
+ attention_P: float = None,
+ attention_wT: int = None,
+ attention_wW: int = None,
+ attention_wH: int = None,
+ attention_add_sta: bool = None,
+ attention_method: str = None,
+ ):
+ super().__init__()
+
+ head_dim = sum(axes_dims)
+ self.in_visual_dim = in_visual_dim
+ self.model_dim = model_dim
+ self.patch_size = patch_size
+ self.visual_cond = visual_cond
+ self.attention_type = attention_type
+
+ visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim
+
+ # Initialize embeddings
+ self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim)
+ self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim)
+ self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim)
+ self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size)
+
+ # Initialize positional embeddings
+ self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim)
+ self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims)
+
+ # Initialize transformer blocks
+ self.text_transformer_blocks = nn.ModuleList(
+ [Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)]
+ )
+
+ self.visual_transformer_blocks = nn.ModuleList(
+ [
+ Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim)
+ for _ in range(num_visual_blocks)
+ ]
+ )
+
+ # Initialize output layer
+ self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor, # x
+ encoder_hidden_states: torch.Tensor, # text_embed
+ timestep: torch.Tensor, # time
+ pooled_projections: torch.Tensor, # pooled_text_embed
+ visual_rope_pos: Tuple[int, int, int],
+ text_rope_pos: torch.LongTensor,
+ scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0),
+ sparse_params: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Transformer2DModelOutput, torch.FloatTensor]:
+ """
+ Forward pass of the Kandinsky5 3D Transformer.
+
+ Args:
+ hidden_states (`torch.FloatTensor`): Input visual states
+ encoder_hidden_states (`torch.FloatTensor`): Text embeddings
+ timestep (`torch.Tensor` or `float` or `int`): Current timestep
+ pooled_projections (`torch.FloatTensor`): Pooled text embeddings
+ visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE
+ text_rope_pos (`torch.LongTensor`): Position for text RoPE
+ scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE
+ sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention
+ return_dict (`bool`, optional): Whether to return a dictionary
+
+ Returns:
+ [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer
+ """
+ x = hidden_states
+ text_embed = encoder_hidden_states
+ time = timestep
+ pooled_text_embed = pooled_projections
+
+ text_embed = self.text_embeddings(text_embed)
+ time_embed = self.time_embeddings(time)
+ time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed)
+ visual_embed = self.visual_embeddings(x)
+ text_rope = self.text_rope_embeddings(text_rope_pos)
+ text_rope = text_rope.unsqueeze(dim=0)
+
+ for text_transformer_block in self.text_transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ text_embed = self._gradient_checkpointing_func(
+ text_transformer_block, text_embed, time_embed, text_rope
+ )
+ else:
+ text_embed = text_transformer_block(text_embed, time_embed, text_rope)
+
+ visual_shape = visual_embed.shape[:-1]
+ visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor)
+ to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False
+ visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal)
+
+ for visual_transformer_block in self.visual_transformer_blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ visual_embed = self._gradient_checkpointing_func(
+ visual_transformer_block,
+ visual_embed,
+ text_embed,
+ time_embed,
+ visual_rope,
+ sparse_params,
+ )
+ else:
+ visual_embed = visual_transformer_block(
+ visual_embed, text_embed, time_embed, visual_rope, sparse_params
+ )
+
+ visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal)
+ x = self.out_layer(visual_embed, text_embed, time_embed)
+
+ if not return_dict:
+ return x
+
+ return Transformer2DModelOutput(sample=x)
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 79149fb760..685c73c07c 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -24,6 +24,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -51,6 +52,7 @@ class LTXVideoAttnProcessor:
"""
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
@@ -100,6 +102,7 @@ class LTXVideoAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
@@ -350,7 +353,9 @@ class LTXVideoTransformerBlock(nn.Module):
norm_hidden_states = self.norm1(hidden_states)
num_ada_params = self.scale_shift_table.shape[0]
- ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+ ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+ batch_size, temb.size(1), num_ada_params, -1
+ )
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -409,6 +414,18 @@ class LTXVideoTransformer3DModel(
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
_repeated_blocks = ["LTXVideoTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_attention_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_prx.py b/src/diffusers/models/transformers/transformer_prx.py
new file mode 100644
index 0000000000..9b2664b9cb
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_prx.py
@@ -0,0 +1,770 @@
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn.functional import fold, unfold
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import logging
+from ..attention import AttentionMixin, AttentionModuleMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..embeddings import get_timestep_embedding
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import RMSNorm
+
+
+logger = logging.get_logger(__name__)
+
+
+def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
+ r"""
+ Generates 2D patch coordinate indices for a batch of images.
+
+ Args:
+ batch_size (`int`):
+ Number of images in the batch.
+ height (`int`):
+ Height of the input images (in pixels).
+ width (`int`):
+ Width of the input images (in pixels).
+ patch_size (`int`):
+ Size of the square patches that the image is divided into.
+ device (`torch.device`):
+ The device on which to create the tensor.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
+ image grid.
+ """
+
+ img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
+ img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
+ img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
+ return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
+
+
+def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ r"""
+ Applies rotary positional embeddings (RoPE) to a query tensor.
+
+ Args:
+ xq (`torch.Tensor`):
+ Input tensor of shape `(..., dim)` representing the queries.
+ freqs_cis (`torch.Tensor`):
+ Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of the same shape as `xq` with rotary embeddings applied.
+ """
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
+ # Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
+ freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
+ return xq_out.reshape(*xq.shape).type_as(xq)
+
+
+class PRXAttnProcessor2_0:
+ r"""
+ Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
+ backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
+ raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: "PRXAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ """
+ Apply PRX attention using PRXAttention module.
+
+ Args:
+ attn: PRXAttention module containing projection layers
+ hidden_states: Image tokens [B, L_img, D]
+ encoder_hidden_states: Text tokens [B, L_txt, D]
+ attention_mask: Boolean mask for text tokens [B, L_txt]
+ image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
+ """
+
+ if encoder_hidden_states is None:
+ raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
+
+ # Project image tokens to Q, K, V
+ img_qkv = attn.img_qkv_proj(hidden_states)
+ B, L_img, _ = img_qkv.shape
+ img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
+ img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
+ img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
+
+ # Apply QK normalization to image tokens
+ img_q = attn.norm_q(img_q)
+ img_k = attn.norm_k(img_k)
+
+ # Project text tokens to K, V
+ txt_kv = attn.txt_kv_proj(encoder_hidden_states)
+ B, L_txt, _ = txt_kv.shape
+ txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
+ txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
+ txt_k, txt_v = txt_kv[0], txt_kv[1]
+
+ # Apply K normalization to text tokens
+ txt_k = attn.norm_added_k(txt_k)
+
+ # Apply RoPE to image queries and keys
+ if image_rotary_emb is not None:
+ img_q = apply_rope(img_q, image_rotary_emb)
+ img_k = apply_rope(img_k, image_rotary_emb)
+
+ # Concatenate text and image keys/values
+ k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
+ v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
+
+ # Build attention mask if provided
+ attn_mask_tensor = None
+ if attention_mask is not None:
+ bs, _, l_img, _ = img_q.shape
+ l_txt = txt_k.shape[2]
+
+ if attention_mask.dim() != 2:
+ raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
+ if attention_mask.shape[-1] != l_txt:
+ raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
+
+ device = img_q.device
+ ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
+ attention_mask = attention_mask.to(device=device, dtype=torch.bool)
+ joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
+ attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
+
+ # Apply attention using dispatch_attention_fn for backend support
+ # Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
+ query = img_q.transpose(1, 2) # [B, L_img, H, D]
+ key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
+ value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
+
+ attn_output = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attn_mask_tensor,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+
+ # Reshape from [B, L_img, H, D] to [B, L_img, H*D]
+ batch_size, seq_len, num_heads, head_dim = attn_output.shape
+ attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
+
+ # Apply output projection
+ attn_output = attn.to_out[0](attn_output)
+ if len(attn.to_out) > 1:
+ attn_output = attn.to_out[1](attn_output) # dropout if present
+
+ return attn_output
+
+
+class PRXAttention(nn.Module, AttentionModuleMixin):
+ r"""
+ PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
+ PRX's architecture.
+ """
+
+ _default_processor_cls = PRXAttnProcessor2_0
+ _available_processors = [PRXAttnProcessor2_0]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ bias: bool = False,
+ out_bias: bool = False,
+ eps: float = 1e-6,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.heads = heads
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.query_dim = query_dim
+
+ self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
+
+ self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+ self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+
+ self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
+ self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
+ self.to_out.append(nn.Dropout(0.0))
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ **kwargs,
+ )
+
+
+# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+class PRXEmbedND(nn.Module):
+ r"""
+ N-dimensional rotary positional embedding.
+
+ This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
+ dimension. The embeddings are combined and returned as a single tensor
+
+ Args:
+ dim (int):
+ Base embedding dimension (must be even).
+ theta (int):
+ Scaling factor that controls the frequency spectrum of the rotary embeddings.
+ axes_dim (list[int]):
+ List of embedding dimensions for each axis (each must be even).
+ """
+
+ def __init__(self, dim: int, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
+ assert dim % 2 == 0
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
+ omega = 1.0 / (theta**scale)
+ out = pos.unsqueeze(-1) * omega.unsqueeze(0)
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
+ # Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
+ # out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
+ out = out.reshape(*out.shape[:-1], 2, 2)
+ return out.float()
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(1)
+
+
+class MLPEmbedder(nn.Module):
+ r"""
+ A simple 2-layer MLP used for embedding inputs.
+
+ Args:
+ in_dim (`int`):
+ Dimensionality of the input features.
+ hidden_dim (`int`):
+ Dimensionality of the hidden and output embedding space.
+
+ Returns:
+ `torch.Tensor`:
+ Tensor of shape `(..., hidden_dim)` containing the embedded representations.
+ """
+
+ def __init__(self, in_dim: int, hidden_dim: int):
+ super().__init__()
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+ self.silu = nn.SiLU()
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.out_layer(self.silu(self.in_layer(x)))
+
+
+class Modulation(nn.Module):
+ r"""
+ Modulation network that generates scale, shift, and gating parameters.
+
+ Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
+ two tuples `(shift, scale, gate)`.
+
+ Args:
+ dim (`int`):
+ Dimensionality of the input vector. The output will have `6 * dim` features internally.
+
+ Returns:
+ ((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
+ Two tuples `(shift, scale, gate)`.
+ """
+
+ def __init__(self, dim: int):
+ super().__init__()
+ self.lin = nn.Linear(dim, 6 * dim, bias=True)
+ nn.init.constant_(self.lin.weight, 0)
+ nn.init.constant_(self.lin.bias, 0)
+
+ def forward(
+ self, vec: torch.Tensor
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
+ return tuple(out[:3]), tuple(out[3:])
+
+
+class PRXBlock(nn.Module):
+ r"""
+ Multimodal transformer block with text–image cross-attention, modulation, and MLP.
+
+ Args:
+ hidden_size (`int`):
+ Dimension of the hidden representations.
+ num_heads (`int`):
+ Number of attention heads.
+ mlp_ratio (`float`, *optional*, defaults to 4.0):
+ Expansion ratio for the hidden dimension inside the MLP.
+ qk_scale (`float`, *optional*):
+ Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
+
+ Attributes:
+ img_pre_norm (`nn.LayerNorm`):
+ Pre-normalization applied to image tokens before attention.
+ attention (`PRXAttention`):
+ Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
+ image and text tokens.
+ post_attention_layernorm (`nn.LayerNorm`):
+ Normalization applied after attention.
+ gate_proj / up_proj / down_proj (`nn.Linear`):
+ Feedforward layers forming the gated MLP.
+ mlp_act (`nn.GELU`):
+ Nonlinear activation used in the MLP.
+ modulation (`Modulation`):
+ Produces scale/shift/gating parameters for modulated layers.
+
+ Methods:
+ The forward method performs cross-attention and the MLP with modulation.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: Optional[float] = None,
+ ):
+ super().__init__()
+
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ self.head_dim = hidden_size // num_heads
+ self.scale = qk_scale or self.head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.hidden_size = hidden_size
+
+ # Pre-attention normalization for image tokens
+ self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+ # PRXAttention module with built-in projections and norms
+ self.attention = PRXAttention(
+ query_dim=hidden_size,
+ heads=num_heads,
+ dim_head=self.head_dim,
+ bias=False,
+ out_bias=False,
+ eps=1e-6,
+ processor=PRXAttnProcessor2_0(),
+ )
+
+ # mlp
+ self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
+ self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
+ self.mlp_act = nn.GELU(approximate="tanh")
+
+ self.modulation = Modulation(hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs: Dict[str, Any],
+ ) -> torch.Tensor:
+ r"""
+ Runs modulation-gated cross-attention and MLP, with residual connections.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Image tokens of shape `(B, L_img, hidden_size)`.
+ encoder_hidden_states (`torch.Tensor`):
+ Text tokens of shape `(B, L_txt, hidden_size)`.
+ temb (`torch.Tensor`):
+ Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
+ broadcastable).
+ image_rotary_emb (`torch.Tensor`):
+ Rotary positional embeddings applied inside attention.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
+ **kwargs:
+ Additional keyword arguments for API compatibility.
+
+ Returns:
+ `torch.Tensor`:
+ Updated image tokens of shape `(B, L_img, hidden_size)`.
+ """
+
+ mod_attn, mod_mlp = self.modulation(temb)
+ attn_shift, attn_scale, attn_gate = mod_attn
+ mlp_shift, mlp_scale, mlp_gate = mod_mlp
+
+ hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
+
+ attn_out = self.attention(
+ hidden_states=hidden_states_mod,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ hidden_states = hidden_states + attn_gate * attn_out
+
+ x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
+ hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
+ return hidden_states
+
+
+class FinalLayer(nn.Module):
+ r"""
+ Final projection layer with adaptive LayerNorm modulation.
+
+ This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
+ outputs.
+
+ Args:
+ hidden_size (`int`):
+ Dimensionality of the input tokens.
+ patch_size (`int`):
+ Size of the square image patches.
+ out_channels (`int`):
+ Number of output channels per pixel (e.g. RGB = 3).
+
+ Forward Inputs:
+ x (`torch.Tensor`):
+ Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
+ vec (`torch.Tensor`):
+ Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
+ LayerNorm.
+
+ Returns:
+ `torch.Tensor`:
+ Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
+ """
+
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+ super().__init__()
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+ def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
+
+
+def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
+ r"""
+ Flattens an image tensor into a sequence of non-overlapping patches.
+
+ Args:
+ img (`torch.Tensor`):
+ Input image tensor of shape `(B, C, H, W)`.
+ patch_size (`int`):
+ Size of each square patch. Must evenly divide both `H` and `W`.
+
+ Returns:
+ `torch.Tensor`:
+ Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
+ // patch_size)` is the number of patches.
+ """
+ return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+
+
+def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
+ r"""
+ Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
+
+ Args:
+ seq (`torch.Tensor`):
+ Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
+ patch_size)`.
+ patch_size (`int`):
+ Size of each square patch.
+ shape (`tuple` or `torch.Tensor`):
+ The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
+ height and width.
+
+ Returns:
+ `torch.Tensor`:
+ Reconstructed image tensor of shape `(B, C, H, W)`.
+ """
+ if isinstance(shape, tuple):
+ shape = shape[-2:]
+ elif isinstance(shape, torch.Tensor):
+ shape = (int(shape[0]), int(shape[1]))
+ else:
+ raise NotImplementedError(f"shape type {type(shape)} not supported")
+ return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
+
+
+class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
+ r"""
+ Transformer-based 2D model for text to image generation.
+
+ Args:
+ in_channels (`int`, *optional*, defaults to 16):
+ Number of input channels in the latent image.
+ patch_size (`int`, *optional*, defaults to 2):
+ Size of the square patches used to flatten the input image.
+ context_in_dim (`int`, *optional*, defaults to 2304):
+ Dimensionality of the text conditioning input.
+ hidden_size (`int`, *optional*, defaults to 1792):
+ Dimension of the hidden representation.
+ mlp_ratio (`float`, *optional*, defaults to 3.5):
+ Expansion ratio for the hidden dimension inside MLP blocks.
+ num_heads (`int`, *optional*, defaults to 28):
+ Number of attention heads.
+ depth (`int`, *optional*, defaults to 16):
+ Number of transformer blocks.
+ axes_dim (`list[int]`, *optional*):
+ List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
+ theta (`int`, *optional*, defaults to 10000):
+ Frequency scaling factor for rotary embeddings.
+ time_factor (`float`, *optional*, defaults to 1000.0):
+ Scaling factor applied in timestep embeddings.
+ time_max_period (`int`, *optional*, defaults to 10000):
+ Maximum frequency period for timestep embeddings.
+
+ Attributes:
+ pe_embedder (`EmbedND`):
+ Multi-axis rotary embedding generator for positional encodings.
+ img_in (`nn.Linear`):
+ Projection layer for image patch tokens.
+ time_in (`MLPEmbedder`):
+ Embedding layer for timestep embeddings.
+ txt_in (`nn.Linear`):
+ Projection layer for text conditioning.
+ blocks (`nn.ModuleList`):
+ Stack of transformer blocks (`PRXBlock`).
+ final_layer (`LastLayer`):
+ Projection layer mapping hidden tokens back to patch outputs.
+
+ Methods:
+ attn_processors:
+ Returns a dictionary of all attention processors in the model.
+ set_attn_processor(processor):
+ Replaces attention processors across all attention layers.
+ process_inputs(image_latent, txt):
+ Converts inputs into patch tokens, encodes text, and produces positional encodings.
+ compute_timestep_embedding(timestep, dtype):
+ Creates a timestep embedding of dimension 256, scaled and projected.
+ forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask,
+ **block_kwargs):
+ Runs the sequence of transformer blocks over image and text tokens.
+ forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None,
+ attention_kwargs=None, return_dict=True):
+ Full forward pass from latent input to reconstructed output image.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
+ - `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
+ """
+
+ config_name = "config.json"
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ patch_size: int = 2,
+ context_in_dim: int = 2304,
+ hidden_size: int = 1792,
+ mlp_ratio: float = 3.5,
+ num_heads: int = 28,
+ depth: int = 16,
+ axes_dim: list = None,
+ theta: int = 10000,
+ time_factor: float = 1000.0,
+ time_max_period: int = 10000,
+ ):
+ super().__init__()
+
+ if axes_dim is None:
+ axes_dim = [32, 32]
+
+ # Store parameters directly
+ self.in_channels = in_channels
+ self.patch_size = patch_size
+ self.out_channels = self.in_channels * self.patch_size**2
+
+ self.time_factor = time_factor
+ self.time_max_period = time_max_period
+
+ if hidden_size % num_heads != 0:
+ raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
+
+ pe_dim = hidden_size // num_heads
+
+ if sum(axes_dim) != pe_dim:
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
+
+ self.hidden_size = hidden_size
+ self.num_heads = num_heads
+ self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
+ self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
+
+ self.blocks = nn.ModuleList(
+ [
+ PRXBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=mlp_ratio,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
+
+ self.gradient_checkpointing = False
+
+ def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
+ return self.time_in(
+ get_timestep_embedding(
+ timesteps=timestep,
+ embedding_dim=256,
+ max_period=self.time_max_period,
+ scale=self.time_factor,
+ flip_sin_to_cos=True, # Match original cos, sin order
+ ).to(dtype)
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ r"""
+ Forward pass of the PRXTransformer2DModel.
+
+ The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
+ transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
+
+ Args:
+ hidden_states (`torch.Tensor`):
+ Input latent image tensor of shape `(B, C, H, W)`.
+ timestep (`torch.Tensor`):
+ Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
+ encoder_hidden_states (`torch.Tensor`):
+ Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
+ attention_mask (`torch.Tensor`, *optional*):
+ Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
+ attention_kwargs (`dict`, *optional*):
+ Additional arguments passed to attention layers.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether to return a `Transformer2DModelOutput` or a tuple.
+
+ Returns:
+ `Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
+
+ - `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
+ """
+ # Process text conditioning
+ txt = self.txt_in(encoder_hidden_states)
+
+ # Convert image to sequence and embed
+ img = img2seq(hidden_states, self.patch_size)
+ img = self.img_in(img)
+
+ # Generate positional embeddings
+ bs, _, h, w = hidden_states.shape
+ img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
+ pe = self.pe_embedder(img_ids)
+
+ # Compute time embedding
+ vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
+
+ # Apply transformer blocks
+ for block in self.blocks:
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ img = self._gradient_checkpointing_func(
+ block.__call__,
+ img,
+ txt,
+ vec,
+ pe,
+ attention_mask,
+ )
+ else:
+ img = block(
+ hidden_states=img,
+ encoder_hidden_states=txt,
+ temb=vec,
+ image_rotary_emb=pe,
+ attention_mask=attention_mask,
+ )
+
+ # Final layer and convert back to image
+ img = self.final_layer(img, vec)
+ output = seq2img(img, self.patch_size, hidden_states.shape)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 846add8906..c0fa031b9f 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -25,6 +25,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
@@ -179,7 +180,6 @@ class QwenEmbedRope(nn.Module):
],
dim=1,
)
- self.rope_cache = {}
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
self.scale_rope = scale_rope
@@ -194,10 +194,20 @@ class QwenEmbedRope(nn.Module):
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
- def forward(self, video_fhw, txt_seq_lens, device):
+ def forward(
+ self,
+ video_fhw: Union[Tuple[int, int, int], List[Tuple[int, int, int]]],
+ txt_seq_lens: List[int],
+ device: torch.device,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
- Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
- txt_length: [bs] a list of 1 integers representing the length of the text
+ Args:
+ video_fhw (`Tuple[int, int, int]` or `List[Tuple[int, int, int]]`):
+ A list of 3 integers [frame, height, width] representing the shape of the video.
+ txt_seq_lens (`List[int]`):
+ A list of integers of length batch_size representing the length of each text prompt.
+ device: (`torch.device`):
+ The device on which to perform the RoPE computation.
"""
if self.pos_freqs.device != device:
self.pos_freqs = self.pos_freqs.to(device)
@@ -212,14 +222,8 @@ class QwenEmbedRope(nn.Module):
max_vid_index = 0
for idx, fhw in enumerate(video_fhw):
frame, height, width = fhw
- rope_key = f"{idx}_{height}_{width}"
-
- if not torch.compiler.is_compiling():
- if rope_key not in self.rope_cache:
- self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
- video_freq = self.rope_cache[rope_key]
- else:
- video_freq = self._compute_video_freqs(frame, height, width, idx)
+ # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
+ video_freq = self._compute_video_freqs(frame, height, width, idx)
video_freq = video_freq.to(device)
vid_freqs.append(video_freq)
@@ -234,8 +238,8 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
- @functools.lru_cache(maxsize=None)
- def _compute_video_freqs(self, frame, height, width, idx=0):
+ @functools.lru_cache(maxsize=128)
+ def _compute_video_freqs(self, frame: int, height: int, width: int, idx: int = 0) -> torch.Tensor:
seq_lens = frame * height * width
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -261,6 +265,7 @@ class QwenDoubleStreamAttnProcessor2_0:
"""
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -334,6 +339,7 @@ class QwenDoubleStreamAttnProcessor2_0:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
# Reshape back
@@ -502,6 +508,18 @@ class QwenImageTransformer2DModel(
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py
new file mode 100644
index 0000000000..aaf96175c0
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_sana_video.py
@@ -0,0 +1,703 @@
+# Copyright 2025 The HuggingFace Team and SANA-Video Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import AttentionMixin
+from ..attention_dispatch import dispatch_attention_fn
+from ..attention_processor import Attention
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormSingle, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class GLUMBTempConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 4,
+ norm_type: Optional[str] = None,
+ residual_connection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_channels = int(expand_ratio * in_channels)
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ self.nonlinearity = nn.SiLU()
+ self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+ self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+ self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+ self.norm = None
+ if norm_type == "rms_norm":
+ self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+ self.conv_temp = nn.Conv2d(
+ out_channels, out_channels, kernel_size=(3, 1), stride=1, padding=(1, 0), bias=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.residual_connection:
+ residual = hidden_states
+ batch_size, num_frames, height, width, num_channels = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size * num_frames, height, width, num_channels).permute(0, 3, 1, 2)
+
+ hidden_states = self.conv_inverted(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv_depth(hidden_states)
+ hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+ hidden_states = hidden_states * self.nonlinearity(gate)
+
+ hidden_states = self.conv_point(hidden_states)
+
+ # Temporal aggregation
+ hidden_states_temporal = hidden_states.view(batch_size, num_frames, num_channels, height * width).permute(
+ 0, 2, 1, 3
+ )
+ hidden_states = hidden_states_temporal + self.conv_temp(hidden_states_temporal)
+ hidden_states = hidden_states.permute(0, 2, 3, 1).view(batch_size, num_frames, height, width, num_channels)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class SanaLinearAttnProcessor3_0:
+ r"""
+ Processor for implementing scaled dot-product linear attention.
+ """
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ original_dtype = hidden_states.dtype
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+ # B,N,H,C
+
+ query = F.relu(query)
+ key = F.relu(key)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query_rotate = apply_rotary_emb(query, *rotary_emb)
+ key_rotate = apply_rotary_emb(key, *rotary_emb)
+
+ # B,H,C,N
+ query = query.permute(0, 2, 3, 1)
+ key = key.permute(0, 2, 3, 1)
+ query_rotate = query_rotate.permute(0, 2, 3, 1)
+ key_rotate = key_rotate.permute(0, 2, 3, 1)
+ value = value.permute(0, 2, 3, 1)
+
+ query_rotate, key_rotate, value = query_rotate.float(), key_rotate.float(), value.float()
+
+ z = 1 / (key.sum(dim=-1, keepdim=True).transpose(-2, -1) @ query + 1e-15)
+
+ scores = torch.matmul(value, key_rotate.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query_rotate)
+
+ hidden_states = hidden_states * z
+ # B,H,C,N
+ hidden_states = hidden_states.flatten(1, 2).transpose(1, 2)
+ hidden_states = hidden_states.to(original_dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed
+class WanRotaryPosEmbed(nn.Module):
+ def __init__(
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
+ for dim in [t_dim, h_dim, w_dim]:
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
+
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+
+
+# Copied from diffusers.models.transformers.sana_transformer.SanaModulatedNorm
+class SanaModulatedNorm(nn.Module):
+ def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+ ) -> torch.Tensor:
+ hidden_states = self.norm(hidden_states)
+ shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+ return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+ def __init__(self, embedding_dim):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+ def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+ timesteps_proj = self.time_proj(timestep)
+ timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
+
+ guidance_proj = self.guidance_condition_proj(guidance)
+ guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+ conditioning = timesteps_emb + guidance_emb
+
+ return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ _attention_backend = None
+ _parallel_config = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, attn.heads, head_dim)
+ value = value.view(batch_size, -1, attn.heads, head_dim)
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+
+class SanaVideoTransformerBlock(nn.Module):
+ r"""
+ Transformer block introduced in [Sana-Video](https://huggingface.co/papers/2509.24695).
+ """
+
+ def __init__(
+ self,
+ dim: int = 2240,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ dropout: float = 0.0,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ attention_bias: bool = True,
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ attention_out_bias: bool = True,
+ mlp_ratio: float = 3.0,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ # 1. Self Attention
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=norm_eps)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ kv_heads=num_attention_heads if qk_norm is not None else None,
+ qk_norm=qk_norm,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ processor=SanaLinearAttnProcessor3_0(),
+ )
+
+ # 2. Cross Attention
+ if cross_attention_dim is not None:
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+ self.attn2 = Attention(
+ query_dim=dim,
+ qk_norm=qk_norm,
+ kv_heads=num_cross_attention_heads if qk_norm is not None else None,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_cross_attention_heads,
+ dim_head=cross_attention_head_dim,
+ dropout=dropout,
+ bias=True,
+ out_bias=attention_out_bias,
+ processor=SanaAttnProcessor2_0(),
+ )
+
+ # 3. Feed-forward
+ self.ff = GLUMBTempConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ frames: int = None,
+ height: int = None,
+ width: int = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ # 1. Modulation
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+ ).chunk(6, dim=1)
+
+ # 2. Self Attention
+ norm_hidden_states = self.norm1(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+ norm_hidden_states = norm_hidden_states.to(hidden_states.dtype)
+
+ attn_output = self.attn1(norm_hidden_states, rotary_emb=rotary_emb)
+ hidden_states = hidden_states + gate_msa * attn_output
+
+ # 3. Cross Attention
+ if self.attn2 is not None:
+ attn_output = self.attn2(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # 4. Feed-forward
+ norm_hidden_states = self.norm2(hidden_states)
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ norm_hidden_states = norm_hidden_states.unflatten(1, (frames, height, width))
+ ff_output = self.ff(norm_hidden_states)
+ ff_output = ff_output.flatten(1, 3)
+ hidden_states = hidden_states + gate_mlp * ff_output
+
+ return hidden_states
+
+
+class SanaVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, AttentionMixin):
+ r"""
+ A 3D Transformer model introduced in [Sana-Video](https://huggingface.co/papers/2509.24695) family of models.
+
+ Args:
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, *optional*, defaults to `16`):
+ The number of channels in the output.
+ num_attention_heads (`int`, defaults to `20`):
+ The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, defaults to `112`):
+ The number of channels in each head.
+ num_layers (`int`, defaults to `20`):
+ The number of layers of Transformer blocks to use.
+ num_cross_attention_heads (`int`, *optional*, defaults to `20`):
+ The number of heads to use for cross-attention.
+ cross_attention_head_dim (`int`, *optional*, defaults to `112`):
+ The number of channels in each head for cross-attention.
+ cross_attention_dim (`int`, *optional*, defaults to `2240`):
+ The number of channels in the cross-attention output.
+ caption_channels (`int`, defaults to `2304`):
+ The number of channels in the caption embeddings.
+ mlp_ratio (`float`, defaults to `2.5`):
+ The expansion ratio to use in the GLUMBConv layer.
+ dropout (`float`, defaults to `0.0`):
+ The dropout probability.
+ attention_bias (`bool`, defaults to `False`):
+ Whether to use bias in the attention layer.
+ sample_size (`int`, defaults to `32`):
+ The base size of the input latent.
+ patch_size (`int`, defaults to `1`):
+ The size of the patches to use in the patch embedding layer.
+ norm_elementwise_affine (`bool`, defaults to `False`):
+ Whether to use elementwise affinity in the normalization layer.
+ norm_eps (`float`, defaults to `1e-6`):
+ The epsilon value for the normalization layer.
+ qk_norm (`str`, *optional*, defaults to `None`):
+ The normalization to use for the query and key.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["SanaVideoTransformerBlock", "SanaModulatedNorm"]
+ _skip_layerwise_casting_patterns = ["patch_embedding", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 16,
+ out_channels: Optional[int] = 16,
+ num_attention_heads: int = 20,
+ attention_head_dim: int = 112,
+ num_layers: int = 20,
+ num_cross_attention_heads: Optional[int] = 20,
+ cross_attention_head_dim: Optional[int] = 112,
+ cross_attention_dim: Optional[int] = 2240,
+ caption_channels: int = 2304,
+ mlp_ratio: float = 2.5,
+ dropout: float = 0.0,
+ attention_bias: bool = False,
+ sample_size: int = 30,
+ patch_size: Tuple[int, int, int] = (1, 2, 2),
+ norm_elementwise_affine: bool = False,
+ norm_eps: float = 1e-6,
+ interpolation_scale: Optional[int] = None,
+ guidance_embeds: bool = False,
+ guidance_embeds_scale: float = 0.1,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ rope_max_seq_len: int = 1024,
+ ) -> None:
+ super().__init__()
+
+ out_channels = out_channels or in_channels
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # 1. Patch & position embedding
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Additional condition embeddings
+ if guidance_embeds:
+ self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+ else:
+ self.time_embed = AdaLayerNormSingle(inner_dim)
+
+ self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+ self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
+
+ # 3. Transformer blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ SanaVideoTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ num_cross_attention_heads=num_cross_attention_heads,
+ cross_attention_head_dim=cross_attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ norm_elementwise_affine=norm_elementwise_affine,
+ norm_eps=norm_eps,
+ mlp_ratio=mlp_ratio,
+ qk_norm=qk_norm,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output blocks
+ self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
+ self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out = nn.Linear(inner_dim, math.prod(patch_size) * out_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ guidance: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None,
+ return_dict: bool = True,
+ ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ if guidance is not None:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+ )
+ else:
+ timestep, embedded_timestep = self.time_embed(
+ timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+ )
+
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+ encoder_hidden_states = self.caption_norm(encoder_hidden_states)
+
+ # 2. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ else:
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = block(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ timestep,
+ post_patch_num_frames,
+ post_patch_height,
+ post_patch_width,
+ rotary_emb,
+ )
+ if controlnet_block_samples is not None and 0 < index_block <= len(controlnet_block_samples):
+ hidden_states = hidden_states + controlnet_block_samples[index_block - 1]
+
+ # 3. Normalization
+ hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ # 5. Unpatchify
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index edf77a7df7..762d89c303 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -280,11 +280,7 @@ class SD3Transformer2DModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -304,11 +300,7 @@ class SD3Transformer2DModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
index 358759164b..6b600aa224 100644
--- a/src/diffusers/models/transformers/transformer_skyreels_v2.py
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -73,6 +73,7 @@ def _get_added_kv_projections(attn: "SkyReelsV2Attention", encoder_hidden_states
class SkyReelsV2AttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -139,6 +140,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -151,6 +153,7 @@ class SkyReelsV2AttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 968a0369c2..dd75fb124f 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -23,6 +23,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
+from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
@@ -66,6 +67,7 @@ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: t
class WanAttnProcessor:
_attention_backend = None
+ _parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
@@ -132,6 +134,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
@@ -144,6 +147,7 @@ class WanAttnProcessor:
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
+ parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
@@ -539,6 +543,19 @@ class WanTransformer3DModel(
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["WanTransformerBlock"]
+ _cp_plan = {
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "blocks.*": {
+ "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
@register_to_config
def __init__(
@@ -665,12 +682,12 @@ class WanTransformer3DModel(
# 5. Output norm, projection & unpatchify
if temb.ndim == 3:
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
- shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift = shift.squeeze(2)
scale = scale.squeeze(2)
else:
# batch_size, inner_dim
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index e5a9c7e0a6..30c38c244a 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -103,7 +103,7 @@ class WanVACETransformerBlock(nn.Module):
control_hidden_states = control_hidden_states + hidden_states
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb.float()
+ self.scale_shift_table.to(temb.device) + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
@@ -361,7 +361,7 @@ class WanVACETransformer3DModel(
hidden_states = hidden_states + control_hint * scale
# 6. Output norm, projection & unpatchify
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/unets/unet_1d.py b/src/diffusers/models/unets/unet_1d.py
index 4f57f3349b..4c4c528a59 100644
--- a/src/diffusers/models/unets/unet_1d.py
+++ b/src/diffusers/models/unets/unet_1d.py
@@ -82,6 +82,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
out_channels: int = 2,
extra_in_channels: int = 0,
time_embedding_type: str = "fourier",
+ time_embedding_dim: Optional[int] = None,
flip_sin_to_cos: bool = True,
use_timestep_embedding: bool = False,
freq_shift: float = 0.0,
@@ -100,15 +101,23 @@ class UNet1DModel(ModelMixin, ConfigMixin):
# time
if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
- embedding_size=8, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ embedding_size=time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
- timestep_input_dim = 2 * block_out_channels[0]
+ timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(
block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift
)
timestep_input_dim = block_out_channels[0]
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
if use_timestep_embedding:
time_embed_dim = block_out_channels[0] * 4
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 736deb28c3..f04d3dfa01 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -16,7 +16,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -872,11 +871,7 @@ class UNet2DConditionModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -895,11 +890,7 @@ class UNet2DConditionModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_3d_condition.py b/src/diffusers/models/unets/unet_3d_condition.py
index bd67ea414a..6a119185b8 100644
--- a/src/diffusers/models/unets/unet_3d_condition.py
+++ b/src/diffusers/models/unets/unet_3d_condition.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
@@ -508,11 +507,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -532,11 +527,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_i2vgen_xl.py b/src/diffusers/models/unets/unet_i2vgen_xl.py
index 8449bf894c..3dba8edca7 100644
--- a/src/diffusers/models/unets/unet_i2vgen_xl.py
+++ b/src/diffusers/models/unets/unet_i2vgen_xl.py
@@ -16,7 +16,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
@@ -472,11 +471,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -496,11 +491,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/models/unets/unet_kandinsky3.py b/src/diffusers/models/unets/unet_kandinsky3.py
index 423669a22f..27241ce2e6 100644
--- a/src/diffusers/models/unets/unet_kandinsky3.py
+++ b/src/diffusers/models/unets/unet_kandinsky3.py
@@ -16,7 +16,6 @@ from dataclasses import dataclass
from typing import Dict, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
diff --git a/src/diffusers/models/unets/unet_motion_model.py b/src/diffusers/models/unets/unet_motion_model.py
index 0a112b5249..18d5eb917f 100644
--- a/src/diffusers/models/unets/unet_motion_model.py
+++ b/src/diffusers/models/unets/unet_motion_model.py
@@ -18,7 +18,6 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
@@ -1911,11 +1910,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1935,11 +1930,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
index 65c22b349b..86ed735134 100644
--- a/src/diffusers/modular_pipelines/__init__.py
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -46,12 +46,19 @@ else:
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
- _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+ _import_structure["flux"] = [
+ "FluxAutoBlocks",
+ "FluxModularPipeline",
+ "FluxKontextAutoBlocks",
+ "FluxKontextModularPipeline",
+ ]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
"QwenImageEditModularPipeline",
"QwenImageEditAutoBlocks",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageEditPlusAutoBlocks",
]
_import_structure["components_manager"] = ["ComponentsManager"]
@@ -63,7 +70,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
- from .flux import FluxAutoBlocks, FluxModularPipeline
+ from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
@@ -78,6 +85,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
+ QwenImageEditPlusAutoBlocks,
+ QwenImageEditPlusModularPipeline,
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
index f48a227e2e..cb7e8fb736 100644
--- a/src/diffusers/modular_pipelines/components_manager.py
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -25,6 +25,7 @@ from ..utils import (
is_accelerate_available,
logging,
)
+from ..utils.torch_utils import get_device
if is_accelerate_available():
@@ -161,7 +162,13 @@ class AutoOffloadStrategy:
current_module_size = model.get_memory_footprint()
- mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
+ device_type = execution_device.type
+ device_module = getattr(torch, device_type, torch.cuda)
+ try:
+ mem_on_device = device_module.mem_get_info(execution_device.index)[0]
+ except AttributeError:
+ raise AttributeError(f"Do not know how to obtain obtain memory info for {str(device_module)}.")
+
mem_on_device = mem_on_device - self.memory_reserve_margin
if current_module_size < mem_on_device:
return []
@@ -283,11 +290,7 @@ class ComponentsManager:
encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
management, and component organization.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
Example:
```python
@@ -301,7 +304,7 @@ class ComponentsManager:
cm.add("vae", vae_model, collection="sdxl")
# Enable auto offloading
- cm.enable_auto_cpu_offload(device="cuda")
+ cm.enable_auto_cpu_offload()
# Retrieve components
unet = cm.get_one(name="unet", collection="sdxl")
@@ -490,6 +493,8 @@ class ComponentsManager:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ if torch.xpu.is_available():
+ torch.xpu.empty_cache()
# YiYi TODO: rename to search_components for now, may remove this method
def search_components(
@@ -678,7 +683,7 @@ class ComponentsManager:
return get_return_dict(matches, return_dict_with_names)
- def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
+ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None, memory_reserve_margin="3GB"):
"""
Enable automatic CPU offloading for all components.
@@ -698,12 +703,16 @@ class ComponentsManager:
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
+ # TODO: add a warning if mem_get_info isn't available on `device`.
+
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
remove_hook_from_module(component, recurse=True)
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
+ if device is None:
+ device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py
index 2891edf790..ec00986611 100644
--- a/src/diffusers/modular_pipelines/flux/__init__.py
+++ b/src/diffusers/modular_pipelines/flux/__init__.py
@@ -25,14 +25,18 @@ else:
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
+ "AUTO_BLOCKS_KONTEXT",
+ "FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
- "FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
+ "FluxKontextAutoBlocks",
+ "FluxKontextAutoDenoiseStep",
+ "FluxKontextBeforeDenoiseStep",
]
- _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+ _import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -45,13 +49,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
+ AUTO_BLOCKS_KONTEXT,
+ FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
+ FluxKontextAutoBlocks,
+ FluxKontextAutoDenoiseStep,
+ FluxKontextBeforeDenoiseStep,
)
- from .modular_pipeline import FluxModularPipeline
+ from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys
diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py
index 4272066309..daffec9865 100644
--- a/src/diffusers/modular_pipelines/flux/before_denoise.py
+++ b/src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -13,12 +13,12 @@
# limitations under the License.
import inspect
-from typing import Any, List, Optional, Tuple, Union
+from typing import List, Optional, Union
import numpy as np
import torch
-from ...models import AutoencoderKL
+from ...pipelines import FluxPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
@@ -104,48 +104,6 @@ def calculate_shift(
return mu
-# Adapted from the original implementation.
-def prepare_latents_img2img(
- vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator
-):
- if isinstance(generator, list) and len(generator) != batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
- )
-
- vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
- latent_channels = vae.config.latent_channels
-
- # VAE applies 8x compression on images but we must also account for packing which requires
- # latent height and width to be divisible by 2.
- height = 2 * (int(height) // (vae_scale_factor * 2))
- width = 2 * (int(width) // (vae_scale_factor * 2))
- shape = (batch_size, num_channels_latents, height, width)
- latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
- image = image.to(device=device, dtype=dtype)
- if image.shape[1] != latent_channels:
- image_latents = _encode_vae_image(image=image, generator=generator)
- else:
- image_latents = image
- if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
- # expand init_latents for batch_size
- additional_image_per_prompt = batch_size // image_latents.shape[0]
- image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
- elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
- raise ValueError(
- f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
- )
- else:
- image_latents = torch.cat([image_latents], dim=0)
-
- noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = scheduler.scale_noise(image_latents, timestep, noise)
- latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
- return latents, latent_image_ids
-
-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -160,43 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
-def _pack_latents(latents, batch_size, num_channels_latents, height, width):
- latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
- latents = latents.permute(0, 2, 4, 1, 3, 5)
- latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
- return latents
-
-
-def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
- latent_image_ids = torch.zeros(height, width, 3)
- latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
- latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
- latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
- latent_image_ids = latent_image_ids.reshape(
- latent_image_id_height * latent_image_id_width, latent_image_id_channels
- )
-
- return latent_image_ids.to(device=device, dtype=dtype)
-
-
-# Cannot use "# Copied from" because it introduces weird indentation errors.
-def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
- if isinstance(generator, list):
- image_latents = [
- retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
- ]
- image_latents = torch.cat(image_latents, dim=0)
- else:
- image_latents = retrieve_latents(vae.encode(image), generator=generator)
-
- image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
-
- return image_latents
-
-
def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
@@ -231,92 +152,6 @@ def _get_initial_timesteps_and_optionals(
return timesteps, num_inference_steps, sigmas, guidance
-class FluxInputStep(ModularPipelineBlocks):
- model_name = "flux"
-
- @property
- def description(self) -> str:
- return (
- "Input processing step that:\n"
- " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
- " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
- "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
- "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
- "have a final batch_size of batch_size * num_images_per_prompt."
- )
-
- @property
- def inputs(self) -> List[InputParam]:
- return [
- InputParam("num_images_per_prompt", default=1),
- InputParam(
- "prompt_embeds",
- required=True,
- type_hint=torch.Tensor,
- description="Pre-generated text embeddings. Can be generated from text_encoder step.",
- ),
- InputParam(
- "pooled_prompt_embeds",
- type_hint=torch.Tensor,
- description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
- ),
- # TODO: support negative embeddings?
- ]
-
- @property
- def intermediate_outputs(self) -> List[str]:
- return [
- OutputParam(
- "batch_size",
- type_hint=int,
- description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
- ),
- OutputParam(
- "dtype",
- type_hint=torch.dtype,
- description="Data type of model tensor inputs (determined by `prompt_embeds`)",
- ),
- OutputParam(
- "prompt_embeds",
- type_hint=torch.Tensor,
- description="text embeddings used to guide the image generation",
- ),
- OutputParam(
- "pooled_prompt_embeds",
- type_hint=torch.Tensor,
- description="pooled text embeddings used to guide the image generation",
- ),
- # TODO: support negative embeddings?
- ]
-
- def check_inputs(self, components, block_state):
- if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
- if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
- raise ValueError(
- "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
- f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
- f" {block_state.pooled_prompt_embeds.shape}."
- )
-
- @torch.no_grad()
- def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
- # TODO: consider adding negative embeddings?
- block_state = self.get_block_state(state)
- self.check_inputs(components, block_state)
-
- block_state.batch_size = block_state.prompt_embeds.shape[0]
- block_state.dtype = block_state.prompt_embeds.dtype
-
- _, seq_len, _ = block_state.prompt_embeds.shape
- block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
- block_state.prompt_embeds = block_state.prompt_embeds.view(
- block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
- )
- self.set_block_state(state, block_state)
-
- return components, state
-
-
class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux"
@@ -385,6 +220,10 @@ class FluxSetTimestepsStep(ModularPipelineBlocks):
block_state.sigmas = sigmas
block_state.guidance = guidance
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ components.scheduler.set_begin_index(0)
+
self.set_block_state(state, block_state)
return components, state
@@ -428,11 +267,6 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
type_hint=int,
description="The number of denoising steps to perform at inference time",
),
- OutputParam(
- "latent_timestep",
- type_hint=torch.Tensor,
- description="The timestep that represents the initial noise level for image-to-image generation",
- ),
OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
]
@@ -480,8 +314,6 @@ class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
block_state.sigmas = sigmas
block_state.guidance = guidance
- block_state.latent_timestep = timesteps[:1].repeat(batch_size)
-
self.set_block_state(state, block_state)
return components, state
@@ -520,11 +352,6 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
OutputParam(
"latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
),
- OutputParam(
- "latent_image_ids",
- type_hint=torch.Tensor,
- description="IDs computed from the image sequence needed for RoPE",
- ),
]
@staticmethod
@@ -548,20 +375,13 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
generator,
latents=None,
):
- # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
- # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
- # to go with the "# Copied from ..." approach. Or maybe there's a way?
-
- # VAE applies 8x compression on images but we must also account for packing which requires
- # latent height and width to be divisible by 2.
height = 2 * (int(height) // (comp.vae_scale_factor * 2))
width = 2 * (int(width) // (comp.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
if latents is not None:
- latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
- return latents.to(device=device, dtype=dtype), latent_image_ids
+ return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -569,26 +389,23 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
+ # TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
- latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)
- latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
- return latents, latent_image_ids
+ return latents
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
-
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
- block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
block_state.num_channels_latents = components.num_channels_latents
self.check_inputs(components, block_state)
batch_size = block_state.batch_size * block_state.num_images_per_prompt
- block_state.latents, block_state.latent_image_ids = self.prepare_latents(
+ block_state.latents = self.prepare_latents(
components,
batch_size,
block_state.num_channels_latents,
@@ -608,82 +425,194 @@ class FluxPrepareLatentsStep(ModularPipelineBlocks):
class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux"
- @property
- def expected_components(self) -> List[ComponentSpec]:
- return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
-
@property
def description(self) -> str:
- return "Step that prepares the latents for the image-to-image generation process"
+ return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`,"
+ " `prepare_latents`. Both noise and image latents should already be patchified."
@property
- def inputs(self) -> List[Tuple[str, Any]]:
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def inputs(self) -> List[InputParam]:
return [
- InputParam("height", type_hint=int),
- InputParam("width", type_hint=int),
- InputParam("latents", type_hint=Optional[torch.Tensor]),
- InputParam("num_images_per_prompt", type_hint=int, default=1),
- InputParam("generator"),
InputParam(
- "image_latents",
+ name="latents",
required=True,
type_hint=torch.Tensor,
- description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+ description="The initial random noised, can be generated in prepare latent step.",
),
InputParam(
- "latent_timestep",
+ name="image_latents",
required=True,
type_hint=torch.Tensor,
- description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+ description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.",
),
InputParam(
- "batch_size",
+ name="timesteps",
required=True,
- type_hint=int,
- description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
),
- InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
- "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
- ),
- OutputParam(
- "latent_image_ids",
+ name="initial_noise",
type_hint=torch.Tensor,
- description="IDs computed from the image sequence needed for RoPE",
+ description="The initial random noised used for inpainting denoising.",
),
]
+ @staticmethod
+ def check_inputs(image_latents, latents):
+ if image_latents.shape[0] != latents.shape[0]:
+ raise ValueError(
+ f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+ )
+
+ if image_latents.ndim != 3:
+ raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
- block_state.device = components._execution_device
- block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
- block_state.num_channels_latents = components.num_channels_latents
- block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
- block_state.device = components._execution_device
+ self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents)
- # TODO: implement `check_inputs`
- batch_size = block_state.batch_size * block_state.num_images_per_prompt
- if block_state.latents is None:
- block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
- components.vae,
- components.scheduler,
- block_state.image_latents,
- block_state.latent_timestep,
- batch_size,
- block_state.num_channels_latents,
- block_state.height,
- block_state.width,
- block_state.dtype,
- block_state.device,
- block_state.generator,
- )
+ # prepare latent timestep
+ latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+ # make copy of initial_noise
+ block_state.initial_noise = block_state.latents
+
+ # scale noise
+ block_state.latents = components.scheduler.scale_noise(
+ block_state.image_latents, latent_timestep, block_state.latents
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="height", required=True),
+ InputParam(name="width", required=True),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(name="image_height"),
+ InputParam(name="image_width"),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="prompt_embeds"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ name="txt_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
+ ),
+ OutputParam(
+ name="img_ids",
+ kwargs_type="denoiser_input_fields",
+ type_hint=List[int],
+ description="The sequence lengths of the image latents, used for RoPE calculation.",
+ ),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ prompt_embeds = block_state.prompt_embeds
+ device, dtype = prompt_embeds.device, prompt_embeds.dtype
+ block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
+ device=prompt_embeds.device, dtype=prompt_embeds.dtype
+ )
+
+ img_ids = None
+ if (
+ getattr(block_state, "image_height", None) is not None
+ and getattr(block_state, "image_width", None) is not None
+ ):
+ image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
+ image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
+ img_ids = FluxPipeline._prepare_latent_image_ids(
+ None, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ img_ids[..., 0] = 1
+
+ height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
+ width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
+ latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)
+
+ if img_ids is not None:
+ latent_ids = torch.cat([latent_ids, img_ids], dim=0)
+
+ block_state.img_ids = latent_ids
self.set_block_state(state, block_state)
diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py
index ffa0a4456f..5a769df103 100644
--- a/src/diffusers/modular_pipelines/flux/denoise.py
+++ b/src/diffusers/modular_pipelines/flux/denoise.py
@@ -59,7 +59,7 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
),
InputParam(
"guidance",
- required=True,
+ required=False,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
@@ -76,18 +76,17 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
description="Pooled prompt embeddings",
),
InputParam(
- "text_ids",
+ "txt_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
- "latent_image_ids",
+ "img_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from image sequence needed for RoPE",
),
- # TODO: guidance
]
@torch.no_grad()
@@ -101,8 +100,8 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
- txt_ids=block_state.text_ids,
- img_ids=block_state.latent_image_ids,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
return_dict=False,
)[0]
block_state.noise_pred = noise_pred
@@ -110,6 +109,96 @@ class FluxLoopDenoiser(ModularPipelineBlocks):
return components, block_state
+class FluxKontextLoopDenoiser(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents for Flux Kontext. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("joint_attention_kwargs"),
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ type_hint=torch.Tensor,
+ description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "guidance",
+ required=False,
+ type_hint=torch.Tensor,
+ description="Guidance scale as a tensor",
+ ),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Prompt embeddings",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pooled prompt embeddings",
+ ),
+ InputParam(
+ "txt_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from text sequence needed for RoPE",
+ ),
+ InputParam(
+ "img_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from latent sequence needed for RoPE",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ latents = block_state.latents
+ latent_model_input = latents
+ image_latents = block_state.image_latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)
+
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ noise_pred = components.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=block_state.guidance,
+ encoder_hidden_states=block_state.prompt_embeds,
+ pooled_projections=block_state.pooled_prompt_embeds,
+ joint_attention_kwargs=block_state.joint_attention_kwargs,
+ txt_ids=block_state.txt_ids,
+ img_ids=block_state.img_ids,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+ block_state.noise_pred = noise_pred
+
+ return components, block_state
+
+
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"
@@ -195,9 +284,6 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
block_state.num_warmup_steps = max(
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
)
- # We set the index here to remove DtoH sync, helpful especially during compilation.
- # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
- components.scheduler.set_begin_index(0)
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
for i, t in enumerate(block_state.timesteps):
components, block_state = self.loop_step(components, block_state, i=i, t=t)
@@ -225,3 +311,20 @@ class FluxDenoiseStep(FluxDenoiseLoopWrapper):
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)
+
+
+class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
+ block_names = ["denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
+ " - `FluxKontextLoopDenoiser`\n"
+ " - `FluxLoopAfterDenoiser`\n"
+ "This block supports both text2image and img2img tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py
index 8c49990280..f0314d4771 100644
--- a/src/diffusers/modular_pipelines/flux/encoders.py
+++ b/src/diffusers/modular_pipelines/flux/encoders.py
@@ -20,12 +20,12 @@ import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
-from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline
@@ -67,17 +67,31 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")
-class FluxVaeEncoderStep(ModularPipelineBlocks):
+def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode)
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode)
+
+ image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
+
+ return image_latents
+
+
+class FluxProcessImagesInputStep(ModularPipelineBlocks):
model_name = "flux"
@property
def description(self) -> str:
- return "Vae Encoder step that encode the input image into a latent representation"
+ return "Image Preprocess step."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
- ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
@@ -88,68 +102,181 @@ class FluxVaeEncoderStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.resized_image is None and block_state.image is None:
+ raise ValueError("`resized_image` and `image` cannot be None at the same time")
+
+ if block_state.resized_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ else:
+ width, height = block_state.resized_image[0].size
+ image = block_state.resized_image
+
+ block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Image preprocess step for Flux Kontext. The preprocessed image goes to the VAE.\n"
+ "Kontext works as a T2I model, too, in case no input image is provided."
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
return [
- InputParam("image", required=True),
- InputParam("height"),
- InputParam("width"),
- InputParam("generator"),
- InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
- InputParam(
- "preprocess_kwargs",
- type_hint=Optional[dict],
- description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
),
]
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("image"), InputParam("_auto_resize", type_hint=bool, default=True)]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam(name="processed_image")]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState):
+ from ...pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
+
+ block_state = self.get_block_state(state)
+ images = block_state.image
+
+ if images is None:
+ block_state.processed_image = None
+
+ else:
+ multiple_of = components.image_processor.config.vae_scale_factor
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if is_valid_image(images):
+ images = [images]
+
+ img = images[0]
+ image_height, image_width = components.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ _auto_resize = block_state._auto_resize
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ images = components.image_processor.resize(images, image_height, image_width)
+ block_state.processed_image = components.image_processor.preprocess(images, image_height, image_width)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"
+ ):
+ """Initialize a VAE encoder step for converting images to latent representations.
+
+ Both the input and output names are configurable so this block can be configured to process to different image
+ inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents").
+
+ Args:
+ input_name (str, optional): Name of the input image tensor. Defaults to "processed_image".
+ Examples: "processed_image" or "processed_control_image"
+ output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents".
+ Examples: "image_latents" or "control_image_latents"
+ sample_mode (str, optional): Sampling mode to be used.
+
+ Examples:
+ # Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep()
+
+ # Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep(
+ input_name="processed_control_image", output_name="control_image_latents"
+ )
+ """
+ self._image_input_name = input_name
+ self._image_latents_output_name = output_name
+ self.sample_mode = sample_mode
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ components = [ComponentSpec("vae", AutoencoderKL)]
+ return components
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [InputParam(self._image_input_name), InputParam("generator")]
+ return inputs
+
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
- "image_latents",
+ self._image_latents_output_name,
type_hint=torch.Tensor,
- description="The latents representing the reference image for image-to-image/inpainting generation",
+ description="The latents representing the reference image",
)
]
- @staticmethod
- # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae
- def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator):
- if isinstance(generator, list):
- image_latents = [
- retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0])
- ]
- image_latents = torch.cat(image_latents, dim=0)
- else:
- image_latents = retrieve_latents(vae.encode(image), generator=generator)
-
- image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor
-
- return image_latents
-
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
- block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
- block_state.device = components._execution_device
- block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ image = getattr(block_state, self._image_input_name)
- block_state.image = components.image_processor.preprocess(
- block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
- )
- block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+ if image is None:
+ setattr(block_state, self._image_latents_output_name, None)
+ else:
+ device = components._execution_device
+ dtype = components.vae.dtype
+ image = image.to(device=device, dtype=dtype)
- block_state.batch_size = block_state.image.shape[0]
-
- # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
- if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
- raise ValueError(
- f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
- f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
+ # Encode image into latents
+ image_latents = encode_vae_image(
+ image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode
)
-
- block_state.image_latents = self._encode_vae_image(
- components.vae, image=block_state.image, generator=block_state.generator
- )
+ setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
@@ -161,7 +288,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
@property
def description(self) -> str:
- return "Text Encoder step that generate text_embeddings to guide the video generation"
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
@property
def expected_components(self) -> List[ComponentSpec]:
@@ -172,15 +299,12 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
ComponentSpec("tokenizer_2", T5TokenizerFast),
]
- @property
- def expected_configs(self) -> List[ConfigSpec]:
- return []
-
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
+ InputParam("max_sequence_length", type_hint=int, default=512, required=False),
InputParam("joint_attention_kwargs"),
]
@@ -189,19 +313,16 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
return [
OutputParam(
"prompt_embeds",
+ kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="pooled text embeddings used to guide the image generation",
),
- OutputParam(
- "text_ids",
- type_hint=torch.Tensor,
- description="ids from the text sequence for RoPE",
- ),
]
@staticmethod
@@ -212,16 +333,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
@staticmethod
def _get_t5_prompt_embeds(
- components,
- prompt: Union[str, List[str]],
- num_images_per_prompt: int,
- max_sequence_length: int,
- device: torch.device,
+ components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device
):
dtype = components.text_encoder_2.dtype
-
prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
@@ -247,23 +362,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- _, seq_len, _ = prompt_embeds.shape
-
- # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
return prompt_embeds
@staticmethod
- def _get_clip_prompt_embeds(
- components,
- prompt: Union[str, List[str]],
- num_images_per_prompt: int,
- device: torch.device,
- ):
+ def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device):
prompt = [prompt] if isinstance(prompt, str) else prompt
- batch_size = len(prompt)
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
@@ -293,10 +396,6 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
prompt_embeds = prompt_embeds.pooler_output
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
- # duplicate text embeddings for each generation per prompt, using mps friendly method
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
- prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
return prompt_embeds
@staticmethod
@@ -305,34 +404,11 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
device: Optional[torch.device] = None,
- num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
lora_scale: Optional[float] = None,
):
- r"""
- Encodes the prompt into text encoder hidden states.
-
- Args:
- prompt (`str` or `List[str]`, *optional*):
- prompt to be encoded
- prompt_2 (`str` or `List[str]`, *optional*):
- The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
- used in all text-encoders
- device: (`torch.device`):
- torch device
- num_images_per_prompt (`int`):
- number of images that should be generated per prompt
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
- If not provided, pooled text embeddings will be generated from `prompt` input argument.
- lora_scale (`float`, *optional*):
- A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
- """
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
@@ -357,12 +433,10 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
components,
prompt=prompt,
device=device,
- num_images_per_prompt=num_images_per_prompt,
)
prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
components,
prompt=prompt_2,
- num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
)
@@ -377,10 +451,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
- dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
- text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
-
- return prompt_embeds, pooled_prompt_embeds, text_ids
+ return prompt_embeds, pooled_prompt_embeds
@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
@@ -396,14 +467,14 @@ class FluxTextEncoderStep(ModularPipelineBlocks):
if block_state.joint_attention_kwargs is not None
else None
)
- (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
+ block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt(
components,
prompt=block_state.prompt,
prompt_2=None,
prompt_embeds=None,
pooled_prompt_embeds=None,
device=block_state.device,
- num_images_per_prompt=1, # TODO: hardcoded for now.
+ max_sequence_length=block_state.max_sequence_length,
lora_scale=block_state.text_encoder_lora_scale,
)
diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py
new file mode 100644
index 0000000000..8309eebfeb
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/inputs.py
@@ -0,0 +1,363 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import torch
+
+from ...pipelines import FluxPipeline
+from ...utils import logging
+from ..modular_pipeline import ModularPipelineBlocks, PipelineState
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+# TODO: consider making these common utilities for modular if they are not pipeline-specific.
+from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__)
+
+
+class FluxTextInputStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Text input processing step that standardizes text embeddings for the pipeline.\n"
+ "This step:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ kwargs_type="denoiser_input_fields",
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="denoiser_input_fields",
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
+ raise ValueError(
+ "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
+ f" {block_state.pooled_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # TODO: consider adding negative embeddings?
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+ pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt)
+ block_state.pooled_prompt_embeds = pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# Adapted from `QwenImageInputsDynamicStep`
+class FluxInputsDynamicStep(ModularPipelineBlocks):
+ model_name = "flux"
+
+ def __init__(
+ self,
+ image_latent_inputs: List[str] = ["image_latents"],
+ additional_batch_inputs: List[str] = [],
+ ):
+ if not isinstance(image_latent_inputs, list):
+ image_latent_inputs = [image_latent_inputs]
+ if not isinstance(additional_batch_inputs, list):
+ additional_batch_inputs = [additional_batch_inputs]
+
+ self._image_latent_inputs = image_latent_inputs
+ self._additional_batch_inputs = additional_batch_inputs
+ super().__init__()
+
+ @property
+ def description(self) -> str:
+ # Functionality section
+ summary_section = (
+ "Input processing step that:\n"
+ " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
+ " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
+ )
+
+ # Inputs info
+ inputs_info = ""
+ if self._image_latent_inputs or self._additional_batch_inputs:
+ inputs_info = "\n\nConfigured inputs:"
+ if self._image_latent_inputs:
+ inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
+ if self._additional_batch_inputs:
+ inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
+
+ # Placement guidance
+ placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
+
+ return summary_section + inputs_info + placement_section
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="num_images_per_prompt", default=1),
+ InputParam(name="batch_size", required=True),
+ InputParam(name="height"),
+ InputParam(name="width"),
+ ]
+
+ # Add image latent inputs
+ for image_latent_input_name in self._image_latent_inputs:
+ inputs.append(InputParam(name=image_latent_input_name))
+
+ # Add additional batch inputs
+ for input_name in self._additional_batch_inputs:
+ inputs.append(InputParam(name=input_name))
+
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ block_state.height = block_state.height or height
+ block_state.width = block_state.width or width
+
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
+ model_name = "flux-kontext"
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+ for image_latent_input_name in self._image_latent_inputs:
+ image_latent_tensor = getattr(block_state, image_latent_input_name)
+ if image_latent_tensor is None:
+ continue
+
+ # 1. Calculate height/width from latents
+ # Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
+ height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
+ # 2. Patchify the image latent tensor
+ # TODO: Implement patchifier for Flux.
+ latent_height, latent_width = image_latent_tensor.shape[2:]
+ image_latent_tensor = FluxPipeline._pack_latents(
+ image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width
+ )
+
+ # 3. Expand batch size
+ image_latent_tensor = repeat_tensor_to_batch_size(
+ input_name=image_latent_input_name,
+ input_tensor=image_latent_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, image_latent_input_name, image_latent_tensor)
+
+ # Process additional batch inputs (only batch expansion)
+ for input_name in self._additional_batch_inputs:
+ input_tensor = getattr(block_state, input_name)
+ if input_tensor is None:
+ continue
+
+ # Only expand batch size
+ input_tensor = repeat_tensor_to_batch_size(
+ input_name=input_name,
+ input_tensor=input_tensor,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ batch_size=block_state.batch_size,
+ )
+
+ setattr(block_state, input_name, input_tensor)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxKontextSetResolutionStep(ModularPipelineBlocks):
+ model_name = "flux-kontext"
+
+ def description(self):
+ return (
+ "Determines the height and width to be used during the subsequent computations.\n"
+ "It should always be placed _before_ the latent preparation step."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ inputs = [
+ InputParam(name="height"),
+ InputParam(name="width"),
+ InputParam(name="max_area", type_hint=int, default=1024**2),
+ ]
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="height", type_hint=int, description="The height of the initial noisy latents"),
+ OutputParam(name="width", type_hint=int, description="The width of the initial noisy latents"),
+ ]
+
+ @staticmethod
+ def check_inputs(height, width, vae_scale_factor):
+ if height is not None and height % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
+
+ if width is not None and width % (vae_scale_factor * 2) != 0:
+ raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
+
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ self.check_inputs(height, width, components.vae_scale_factor)
+
+ original_height, original_width = height, width
+ max_area = block_state.max_area
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = components.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ block_state.height = height
+ block_state.width = width
+
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
index 37895bddbf..a80bc2a5f7 100644
--- a/src/diffusers/modular_pipelines/flux/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -18,21 +18,49 @@ from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
- FluxInputStep,
+ FluxKontextRoPEInputsStep,
FluxPrepareLatentsStep,
+ FluxRoPEInputsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
-from .denoise import FluxDenoiseStep
-from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
+from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
+from .encoders import (
+ FluxKontextProcessImagesInputStep,
+ FluxProcessImagesInputStep,
+ FluxTextEncoderStep,
+ FluxVaeEncoderDynamicStep,
+)
+from .inputs import (
+ FluxInputsDynamicStep,
+ FluxKontextInputsDynamicStep,
+ FluxKontextSetResolutionStep,
+ FluxTextInputStep,
+)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# vae encoder (run before before_denoise)
+FluxImg2ImgVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
+)
+
+
+class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = FluxImg2ImgVaeEncoderBlocks.values()
+ block_names = FluxImg2ImgVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
- block_classes = [FluxVaeEncoderStep]
+ block_classes = [FluxImg2ImgVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@@ -41,52 +69,89 @@ class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
- + " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
- + " - if `image` is provided, step will be skipped."
+ + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if `image` is not provided, step will be skipped."
)
-# before_denoise: text2img, img2img
-class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
- block_classes = [
- FluxInputStep,
- FluxPrepareLatentsStep,
- FluxSetTimestepsStep,
- ]
- block_names = ["input", "prepare_latents", "set_timesteps"]
+# Flux Kontext vae encoder (run before before_denoise)
+
+FluxKontextVaeEncoderBlocks = InsertableDict(
+ [("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
+)
+
+
+class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = FluxKontextVaeEncoderBlocks.values()
+ block_names = FluxKontextVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
+
+
+class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextVaeEncoderStep]
+ block_names = ["img2img"]
+ block_trigger_inputs = ["image"]
@property
def description(self):
return (
- "Before denoise step that prepare the inputs for the denoise step.\n"
- + "This is a sequential pipeline blocks:\n"
- + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
- + " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
- + " - `FluxSetTimestepsStep` is used to set the timesteps\n"
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for img2img tasks.\n"
+ + " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if `image` is not provided, step will be skipped."
)
+# before_denoise: text2img
+FluxBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxBeforeDenoiseBlocks.values()
+ block_names = FluxBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
+
+
# before_denoise: img2img
+FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ]
+)
+
+
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
- block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
- block_names = ["input", "set_timesteps", "prepare_latents"]
+ block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
+ block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
@property
def description(self):
- return (
- "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
- + "This is a sequential pipeline blocks:\n"
- + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
- + " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
- + " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
- )
+ return "Before denoise step that prepare the inputs for the denoise step for img2img task."
# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
- block_classes = [FluxBeforeDenoiseStep, FluxImg2ImgBeforeDenoiseStep]
- block_names = ["text2image", "img2img"]
- block_trigger_inputs = [None, "image_latents"]
+ model_name = "flux-kontext"
+ block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
@property
def description(self):
@@ -98,6 +163,44 @@ class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
)
+# before_denoise: FluxKontext
+
+FluxKontextBeforeDenoiseBlocks = InsertableDict(
+ [
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ]
+)
+
+
+class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = FluxKontextBeforeDenoiseBlocks.values()
+ block_names = FluxKontextBeforeDenoiseBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step\n"
+ "for img2img/text2img task for Flux Kontext."
+ )
+
+
+class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2image.\n"
+ + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ + " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ )
+
+
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
@@ -113,7 +216,24 @@ class FluxAutoDenoiseStep(AutoPipelineBlocks):
)
-# decode: all task (text2img, img2img, inpainting)
+# denoise: Flux Kontext
+
+
+class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextDenoiseStep]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents for Flux Kontext. "
+ "This is a auto pipeline block that works for text2image and img2img tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
+ )
+
+
+# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
block_classes = [FluxDecodeStep]
block_names = ["non-inpaint"]
@@ -124,16 +244,143 @@ class FluxAutoDecodeStep(AutoPipelineBlocks):
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
-# text2image
-class FluxAutoBlocks(SequentialPipelineBlocks):
- block_classes = [
- FluxTextEncoderStep,
- FluxAutoVaeEncoderStep,
- FluxAutoBeforeDenoiseStep,
- FluxAutoDenoiseStep,
- FluxAutoDecodeStep,
+# inputs: text2image/img2img
+FluxImg2ImgBlocks = InsertableDict(
+ [("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
+)
+
+
+class FluxImg2ImgInputStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = FluxImg2ImgBlocks.values()
+ block_names = FluxImg2ImgBlocks.keys()
+
+ @property
+ def description(self):
+ return "Input step that prepares the inputs for the img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+
+
+class FluxAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
+ block_names = ["img2img", "text2image"]
+ block_trigger_inputs = ["image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+ " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
+ )
+
+
+# inputs: Flux Kontext
+
+FluxKontextBlocks = InsertableDict(
+ [
+ ("set_resolution", FluxKontextSetResolutionStep()),
+ ("text_inputs", FluxTextInputStep()),
+ ("additional_inputs", FluxKontextInputsDynamicStep()),
]
- block_names = ["text_encoder", "image_encoder", "before_denoise", "denoise", "decoder"]
+)
+
+
+class FluxKontextInputStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = FluxKontextBlocks.values()
+ block_names = FluxKontextBlocks.keys()
+
+ @property
+ def description(self):
+ return (
+ "Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
+ " - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
+ " - update height/width based `image_latents`, patchify `image_latents`."
+ )
+
+
+class FluxKontextAutoInputStep(AutoPipelineBlocks):
+ block_classes = [FluxKontextInputStep, FluxTextInputStep]
+ # block_classes = [FluxKontextInputStep]
+ block_names = ["img2img", "text2img"]
+ # block_names = ["img2img"]
+ block_trigger_inputs = ["image_latents", None]
+ # block_trigger_inputs = ["image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
+ " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ + " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
+ + " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
+ )
+
+
+class FluxCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux"
+ block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "flux-kontext"
+ block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings."
+ )
+
+
+# Auto blocks (text2image and img2img)
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxAutoVaeEncoderStep()),
+ ("denoise", FluxCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+AUTO_BLOCKS_KONTEXT = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep()),
+ ("image_encoder", FluxKontextAutoVaeEncoderStep()),
+ ("denoise", FluxKontextCoreDenoiseStep()),
+ ("decode", FluxDecodeStep()),
+ ]
+)
+
+
+class FluxAutoBlocks(SequentialPipelineBlocks):
+ model_name = "flux"
+
+ block_classes = AUTO_BLOCKS.values()
+ block_names = AUTO_BLOCKS.keys()
@property
def description(self):
@@ -144,38 +391,56 @@ class FluxAutoBlocks(SequentialPipelineBlocks):
)
+class FluxKontextAutoBlocks(FluxAutoBlocks):
+ model_name = "flux-kontext"
+
+ block_classes = AUTO_BLOCKS_KONTEXT.values()
+ block_names = AUTO_BLOCKS_KONTEXT.keys()
+
+
TEXT2IMAGE_BLOCKS = InsertableDict(
[
- ("text_encoder", FluxTextEncoderStep),
- ("input", FluxInputStep),
- ("prepare_latents", FluxPrepareLatentsStep),
- ("set_timesteps", FluxSetTimestepsStep),
- ("denoise", FluxDenoiseStep),
- ("decode", FluxDecodeStep),
+ ("text_encoder", FluxTextEncoderStep()),
+ ("input", FluxTextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
- ("text_encoder", FluxTextEncoderStep),
- ("image_encoder", FluxVaeEncoderStep),
- ("input", FluxInputStep),
- ("set_timesteps", FluxImg2ImgSetTimestepsStep),
- ("prepare_latents", FluxImg2ImgPrepareLatentsStep),
- ("denoise", FluxDenoiseStep),
- ("decode", FluxDecodeStep),
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep()),
+ ("input", FluxImg2ImgInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+ ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+ ("prepare_rope_inputs", FluxRoPEInputsStep()),
+ ("denoise", FluxDenoiseStep()),
+ ("decode", FluxDecodeStep()),
]
)
-AUTO_BLOCKS = InsertableDict(
+FLUX_KONTEXT_BLOCKS = InsertableDict(
[
- ("text_encoder", FluxTextEncoderStep),
- ("image_encoder", FluxAutoVaeEncoderStep),
- ("before_denoise", FluxAutoBeforeDenoiseStep),
- ("denoise", FluxAutoDenoiseStep),
- ("decode", FluxAutoDecodeStep),
+ ("text_encoder", FluxTextEncoderStep()),
+ ("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
+ ("input", FluxKontextInputStep()),
+ ("prepare_latents", FluxPrepareLatentsStep()),
+ ("set_timesteps", FluxSetTimestepsStep()),
+ ("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
+ ("denoise", FluxKontextDenoiseStep()),
+ ("decode", FluxDecodeStep()),
]
)
-
-ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
+ALL_BLOCKS = {
+ "text2image": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "auto": AUTO_BLOCKS,
+ "auto_kontext": AUTO_BLOCKS_KONTEXT,
+ "kontext": FLUX_KONTEXT_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
index e97445d411..d8158f5d4f 100644
--- a/src/diffusers/modular_pipelines/flux/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
@@ -25,13 +25,11 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
"""
A ModularPipeline for Flux.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "FluxAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@@ -57,3 +55,13 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
if getattr(self, "transformer", None):
num_channels_latents = self.transformer.config.in_channels // 4
return num_channels_latents
+
+
+class FluxKontextModularPipeline(FluxModularPipeline):
+ """
+ A ModularPipeline for Flux Kontext.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "FluxKontextAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/mellon_node_utils.py b/src/diffusers/modular_pipelines/mellon_node_utils.py
new file mode 100644
index 0000000000..a405aebee2
--- /dev/null
+++ b/src/diffusers/modular_pipelines/mellon_node_utils.py
@@ -0,0 +1,763 @@
+import json
+import logging
+import os
+
+# Simple typed wrapper for parameter overrides
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from huggingface_hub import create_repo, hf_hub_download
+from huggingface_hub.utils import (
+ EntryNotFoundError,
+ HfHubHTTPError,
+ RepositoryNotFoundError,
+ RevisionNotFoundError,
+ validate_hf_hub_args,
+)
+
+from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, PushToHubMixin, extract_commit_hash
+from .modular_pipeline import ModularPipelineBlocks
+
+
+logger = logging.getLogger(__name__)
+
+
+SUPPORTED_NODE_TYPES = {"controlnet", "vae_encoder", "denoise", "text_encoder", "decoder"}
+
+
+# Mellon Input Parameters (runtime parameters, not models)
+MELLON_INPUT_PARAMS = {
+ # controlnet
+ "control_image": {
+ "label": "Control Image",
+ "type": "image",
+ "display": "input",
+ },
+ "controlnet_conditioning_scale": {
+ "label": "Scale",
+ "type": "float",
+ "default": 0.5,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_end": {
+ "label": "End",
+ "type": "float",
+ "default": 1.0,
+ "min": 0,
+ "max": 1,
+ },
+ "control_guidance_start": {
+ "label": "Start",
+ "type": "float",
+ "default": 0.0,
+ "min": 0,
+ "max": 1,
+ },
+ "controlnet": {
+ "label": "Controlnet",
+ "type": "custom_controlnet",
+ "display": "input",
+ },
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "input",
+ "type": "embeddings",
+ },
+ "image": {
+ "label": "Image",
+ "type": "image",
+ "display": "input",
+ },
+ "negative_prompt": {
+ "label": "Negative Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "prompt": {
+ "label": "Prompt",
+ "type": "string",
+ "default": "",
+ "display": "textarea",
+ },
+ "guidance_scale": {
+ "label": "Guidance Scale",
+ "type": "float",
+ "display": "slider",
+ "default": 5,
+ "min": 1.0,
+ "max": 30.0,
+ "step": 0.1,
+ },
+ "height": {
+ "label": "Height",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "input",
+ "onChange": {False: ["height", "width"], True: ["strength"]},
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "input",
+ },
+ "num_inference_steps": {
+ "label": "Steps",
+ "type": "int",
+ "display": "slider",
+ "default": 25,
+ "min": 1,
+ "max": 100,
+ },
+ "seed": {
+ "label": "Seed",
+ "type": "int",
+ "display": "random",
+ "default": 0,
+ "min": 0,
+ "max": 4294967295,
+ },
+ "strength": {
+ "label": "Strength",
+ "type": "float",
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ "width": {
+ "label": "Width",
+ "type": "int",
+ "default": 1024,
+ "min": 64,
+ "step": 8,
+ },
+ "ip_adapter": {
+ "label": "IP Adapter",
+ "type": "custom_ip_adapter",
+ "display": "input",
+ },
+}
+
+# Mellon Model Parameters (diffusers_auto_model types)
+MELLON_MODEL_PARAMS = {
+ "scheduler": {
+ "label": "Scheduler",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "text_encoders": {
+ "label": "Text Encoders",
+ "type": "diffusers_auto_models",
+ "display": "input",
+ },
+ "unet": {
+ "label": "Unet",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ "onSignal": {
+ "action": "signal",
+ "target": "guider",
+ },
+ },
+ "guider": {
+ "label": "Guider",
+ "display": "input",
+ "type": "custom_guider",
+ "onChange": {False: ["guidance_scale"], True: []},
+ },
+ "vae": {
+ "label": "VAE",
+ "display": "input",
+ "type": "diffusers_auto_model",
+ },
+ "controlnet": {
+ "label": "Controlnet Model",
+ "type": "diffusers_auto_model",
+ "display": "input",
+ },
+}
+
+# Mellon Output Parameters (display = "output")
+MELLON_OUTPUT_PARAMS = {
+ "embeddings": {
+ "label": "Text Embeddings",
+ "display": "output",
+ "type": "embeddings",
+ },
+ "images": {
+ "label": "Images",
+ "type": "image",
+ "display": "output",
+ },
+ "image_latents": {
+ "label": "Image Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents": {
+ "label": "Latents",
+ "type": "latents",
+ "display": "output",
+ },
+ "latents_preview": {
+ "label": "Latents Preview",
+ "display": "output",
+ "type": "latent",
+ },
+ "controlnet_out": {
+ "label": "Controlnet",
+ "display": "output",
+ "type": "controlnet",
+ },
+}
+
+
+# Default param selections per supported node_type
+# from MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS.
+NODE_TYPE_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ # optional image prompt input supported in embeddings node
+ "image",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
+
+
+@dataclass(frozen=True)
+class MellonParam:
+ name: str
+ label: str
+ type: str
+ display: Optional[str] = None
+ default: Any = None
+ min: Optional[float] = None
+ max: Optional[float] = None
+ step: Optional[float] = None
+ options: Any = None
+ value: Any = None
+ fieldOptions: Optional[Dict[str, Any]] = None
+ onChange: Any = None
+ onSignal: Any = None
+ _map_to_input: Any = None # the block input name this parameter maps to
+
+ def to_dict(self) -> Dict[str, Any]:
+ data = asdict(self)
+ return {k: v for k, v in data.items() if not k.startswith("_") and v is not None}
+
+
+@dataclass
+class MellonNodeConfig(PushToHubMixin):
+ """
+ A MellonNodeConfig is a base class to build Mellon nodes UI with modular diffusers.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ inputs: List[Union[str, MellonParam]]
+ model_inputs: List[Union[str, MellonParam]]
+ outputs: List[Union[str, MellonParam]]
+ blocks_names: list[str]
+ node_type: str
+ config_name = "mellon_config.json"
+
+ def __post_init__(self):
+ if isinstance(self.inputs, list):
+ self.inputs = self._resolve_params_list(self.inputs, MELLON_INPUT_PARAMS)
+ if isinstance(self.model_inputs, list):
+ self.model_inputs = self._resolve_params_list(self.model_inputs, MELLON_MODEL_PARAMS)
+ if isinstance(self.outputs, list):
+ self.outputs = self._resolve_params_list(self.outputs, MELLON_OUTPUT_PARAMS)
+
+ @staticmethod
+ def _resolve_params_list(
+ params: List[Union[str, MellonParam]], default_map: Dict[str, Dict[str, Any]]
+ ) -> Dict[str, Dict[str, Any]]:
+ def _resolve_param(
+ param: Union[str, MellonParam], default_params_map: Dict[str, Dict[str, Any]]
+ ) -> Tuple[str, Dict[str, Any]]:
+ if isinstance(param, str):
+ if param not in default_params_map:
+ raise ValueError(f"Unknown param '{param}', please define a `MellonParam` object instead")
+ return param, default_params_map[param].copy()
+ elif isinstance(param, MellonParam):
+ param_dict = param.to_dict()
+ param_name = param_dict.pop("name")
+ return param_name, param_dict
+ else:
+ raise ValueError(
+ f"Unknown param type '{type(param)}', please use a string or a `MellonParam` object instead"
+ )
+
+ resolved = {}
+ for p in params:
+ logger.info(f" Resolving param: {p}")
+ name, cfg = _resolve_param(p, default_map)
+ if name in resolved:
+ raise ValueError(f"Duplicate param '{name}'")
+ resolved[name] = cfg
+ return resolved
+
+ @classmethod
+ @validate_hf_hub_args
+ def load_mellon_config(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ return_unused_kwargs=False,
+ return_commit_hash=False,
+ **kwargs,
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ r"""
+ Load a model or scheduler configuration.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
+ [`~ConfigMixin.save_config`].
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False):
+ Whether unused keyword arguments of the config are returned.
+ return_commit_hash (`bool`, *optional*, defaults to `False):
+ Whether the `commit_hash` of the loaded configuration are returned.
+
+ Returns:
+ `dict`:
+ A dictionary of all the parameters stored in a JSON configuration file.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ local_dir = kwargs.pop("local_dir", None)
+ local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ if cls.config_name is None:
+ raise ValueError(
+ "`self.config_name` is not defined. Note that one should not load a config from "
+ "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
+ )
+ if os.path.isfile(pretrained_model_name_or_path):
+ config_file = pretrained_model_name_or_path
+ elif os.path.isdir(pretrained_model_name_or_path):
+ if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
+ # Load from a PyTorch checkpoint
+ config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
+ else:
+ raise EnvironmentError(
+ f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
+ )
+ else:
+ try:
+ # Load from URL or cache if already cached
+ config_file = hf_hub_download(
+ pretrained_model_name_or_path,
+ filename=cls.config_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ local_dir=local_dir,
+ local_dir_use_symlinks=local_dir_use_symlinks,
+ )
+ except RepositoryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
+ " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
+ " token having permission to this repo with `token` or log in with `hf auth login`."
+ )
+ except RevisionNotFoundError:
+ raise EnvironmentError(
+ f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
+ " this model name. Check the model page at"
+ f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
+ )
+ except EntryNotFoundError:
+ raise EnvironmentError(
+ f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
+ )
+ except HfHubHTTPError as err:
+ raise EnvironmentError(
+ "There was a specific connection error when trying to load"
+ f" {pretrained_model_name_or_path}:\n{err}"
+ )
+ except ValueError:
+ raise EnvironmentError(
+ f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+ f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+ f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
+ " run the library in offline mode at"
+ " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+ )
+ except EnvironmentError:
+ raise EnvironmentError(
+ f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+ "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+ f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+ f"containing a {cls.config_name} file"
+ )
+ try:
+ with open(config_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ config_dict = json.loads(text)
+
+ commit_hash = extract_commit_hash(config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError):
+ raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
+
+ if not (return_unused_kwargs or return_commit_hash):
+ return config_dict
+
+ outputs = (config_dict,)
+
+ if return_unused_kwargs:
+ outputs += (kwargs,)
+
+ if return_commit_hash:
+ outputs += (commit_hash,)
+
+ return outputs
+
+ def save_mellon_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save the Mellon node definition to a JSON file.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file is saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
+
+ self.to_json_file(output_config_file)
+ logger.info(f"Mellon node definition saved in {output_config_file}")
+
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", None)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+ subfolder = kwargs.pop("subfolder", None)
+
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ subfolder=subfolder,
+ )
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
+ """
+ Save the Mellon schema dictionary to a JSON file.
+
+ Args:
+ json_file_path (`str` or `os.PathLike`):
+ Path to the JSON file to save a configuration instance's parameters.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string())
+
+ def to_json_string(self) -> str:
+ """
+ Serializes this instance to a JSON string of the Mellon schema dict.
+
+ Args:
+ Returns:
+ `str`: String containing all the attributes that make up this configuration instance in JSON format.
+ """
+
+ mellon_dict = self.to_mellon_dict()
+ return json.dumps(mellon_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_mellon_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dict focusing on the Mellon schema fields only.
+
+ params is a single flat dict composed as: {**inputs, **model_inputs, **outputs}.
+ """
+ # inputs/model_inputs/outputs are already normalized dicts
+ merged_params = {}
+ merged_params.update(self.inputs or {})
+ merged_params.update(self.model_inputs or {})
+ merged_params.update(self.outputs or {})
+
+ return {
+ "node_type": self.node_type,
+ "blocks_names": self.blocks_names,
+ "params": merged_params,
+ }
+
+ @classmethod
+ def from_mellon_dict(cls, mellon_dict: Dict[str, Any]) -> "MellonNodeConfig":
+ """Create a config from a Mellon schema dict produced by to_mellon_dict().
+
+ Splits the flat params dict back into inputs/model_inputs/outputs using the known key spaces from
+ MELLON_INPUT_PARAMS / MELLON_MODEL_PARAMS / MELLON_OUTPUT_PARAMS. Unknown keys are treated as inputs by
+ default.
+ """
+ flat_params = mellon_dict.get("params", {})
+
+ inputs: Dict[str, Any] = {}
+ model_inputs: Dict[str, Any] = {}
+ outputs: Dict[str, Any] = {}
+
+ for param_name, param_dict in flat_params.items():
+ if param_dict.get("display", "") == "output":
+ outputs[param_name] = param_dict
+ elif param_dict.get("type", "") in ("diffusers_auto_model", "diffusers_auto_models"):
+ model_inputs[param_name] = param_dict
+ else:
+ inputs[param_name] = param_dict
+
+ return cls(
+ inputs=inputs,
+ model_inputs=model_inputs,
+ outputs=outputs,
+ blocks_names=mellon_dict.get("blocks_names", []),
+ node_type=mellon_dict.get("node_type"),
+ )
+
+ # YiYi Notes: not used yet
+ @classmethod
+ def from_blocks(cls, blocks: ModularPipelineBlocks, node_type: str) -> "MellonNodeConfig":
+ """
+ Create an instance from a ModularPipeline object. If a preset exists in NODE_TYPE_PARAMS_MAP for the node_type,
+ use it; otherwise fall back to deriving lists from the pipeline's expected inputs/components/outputs.
+ """
+ if node_type not in NODE_TYPE_PARAMS_MAP:
+ raise ValueError(f"Node type {node_type} not supported")
+
+ blocks_names = list(blocks.sub_blocks.keys())
+
+ default_node_config = NODE_TYPE_PARAMS_MAP[node_type]
+ inputs_list: List[Union[str, MellonParam]] = default_node_config.get("inputs", [])
+ model_inputs_list: List[Union[str, MellonParam]] = default_node_config.get("model_inputs", [])
+ outputs_list: List[Union[str, MellonParam]] = default_node_config.get("outputs", [])
+
+ for required_input_name in blocks.required_inputs:
+ if required_input_name not in inputs_list:
+ inputs_list.append(
+ MellonParam(
+ name=required_input_name, label=required_input_name, type=required_input_name, display="input"
+ )
+ )
+
+ for component_spec in blocks.expected_components:
+ if component_spec.name not in model_inputs_list:
+ model_inputs_list.append(
+ MellonParam(
+ name=component_spec.name,
+ label=component_spec.name,
+ type="diffusers_auto_model",
+ display="input",
+ )
+ )
+
+ return cls(
+ inputs=inputs_list,
+ model_inputs=model_inputs_list,
+ outputs=outputs_list,
+ blocks_names=blocks_names,
+ node_type=node_type,
+ )
+
+
+# Minimal modular registry for Mellon node configs
+class ModularMellonNodeRegistry:
+ """Registry mapping (pipeline class, blocks_name) -> list of MellonNodeConfig."""
+
+ def __init__(self):
+ self._registry = {}
+ self._initialized = False
+
+ def register(self, pipeline_cls: type, node_params: Dict[str, MellonNodeConfig]):
+ if not self._initialized:
+ _initialize_registry(self)
+ self._registry[pipeline_cls] = node_params
+
+ def get(self, pipeline_cls: type) -> MellonNodeConfig:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry.get(pipeline_cls, None)
+
+ def get_all(self) -> Dict[type, Dict[str, MellonNodeConfig]]:
+ if not self._initialized:
+ _initialize_registry(self)
+ return self._registry
+
+
+def _register_preset_node_types(
+ pipeline_cls, params_map: Dict[str, Dict[str, Any]], registry: ModularMellonNodeRegistry
+):
+ """Register all node-type presets for a given pipeline class from a params map."""
+ node_configs = {}
+ for node_type, spec in params_map.items():
+ node_config = MellonNodeConfig(
+ inputs=spec.get("inputs", []),
+ model_inputs=spec.get("model_inputs", []),
+ outputs=spec.get("outputs", []),
+ blocks_names=spec.get("block_names", []),
+ node_type=node_type,
+ )
+ node_configs[node_type] = node_config
+ registry.register(pipeline_cls, node_configs)
+
+
+def _initialize_registry(registry: ModularMellonNodeRegistry):
+ """Initialize the registry and register all available pipeline configs."""
+ print("Initializing registry")
+
+ registry._initialized = True
+
+ try:
+ from .qwenimage.modular_pipeline import QwenImageModularPipeline
+ from .qwenimage.node_utils import QwenImage_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(QwenImageModularPipeline, QwenImage_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception:
+ raise Exception("Failed to register QwenImageModularPipeline")
+
+ try:
+ from .stable_diffusion_xl.modular_pipeline import StableDiffusionXLModularPipeline
+ from .stable_diffusion_xl.node_utils import SDXL_NODE_TYPES_PARAMS_MAP
+
+ _register_preset_node_types(StableDiffusionXLModularPipeline, SDXL_NODE_TYPES_PARAMS_MAP, registry)
+ except Exception:
+ raise Exception("Failed to register StableDiffusionXLModularPipeline")
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
index 78226a49b1..307698245e 100644
--- a/src/diffusers/modular_pipelines/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -51,23 +51,16 @@ if is_accelerate_available():
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+# map regular pipeline to modular pipeline class name
MODULAR_PIPELINE_MAPPING = OrderedDict(
[
("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
("wan", "WanModularPipeline"),
("flux", "FluxModularPipeline"),
+ ("flux-kontext", "FluxKontextModularPipeline"),
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
- ]
-)
-
-MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
- [
- ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
- ("WanModularPipeline", "WanAutoBlocks"),
- ("FluxModularPipeline", "FluxAutoBlocks"),
- ("QwenImageModularPipeline", "QwenImageAutoBlocks"),
- ("QwenImageEditModularPipeline", "QwenImageEditAutoBlocks"),
+ ("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
]
)
@@ -137,8 +130,14 @@ class PipelineState:
Allow attribute access to intermediate values. If an attribute is not found in the object, look for it in the
intermediates dict.
"""
- if name in self.values:
- return self.values[name]
+ # Use object.__getattribute__ to avoid infinite recursion during deepcopy
+ try:
+ values = object.__getattribute__(self, "values")
+ except AttributeError:
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
+
+ if name in values:
+ return values[name]
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
def __repr__(self):
@@ -235,11 +234,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
config_name = "modular_config.json"
@@ -310,20 +305,20 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"cache_dir",
"force_download",
"local_files_only",
+ "local_dir",
"proxies",
- "resume_download",
"revision",
"subfolder",
"token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
- config = cls.load_config(pretrained_model_name_or_path)
+ config = cls.load_config(pretrained_model_name_or_path, **hub_kwargs)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
- if not (has_remote_code and trust_remote_code):
+ if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
@@ -336,11 +331,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module_file=module_file,
class_name=class_name,
**hub_kwargs,
- **kwargs,
)
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
block_kwargs = {
- name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
+ name: kwargs.get(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
}
return block_cls(**block_kwargs)
@@ -423,7 +417,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
state.set(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
- # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
+ # if it is a kwargs type, e.g. "denoiser_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
@@ -534,11 +528,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -796,11 +786,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -1155,11 +1141,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
@@ -1442,11 +1424,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
Base class for all Modular pipelines.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
Args:
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
@@ -1454,6 +1432,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
config_name = "modular_model_index.json"
hf_device_map = None
+ default_blocks_name = None
# YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
def __init__(
@@ -1514,7 +1493,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
`_blocks_class_name` in the config dict
"""
if blocks is None:
- blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
+ blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name)
@@ -1656,7 +1635,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
- except EnvironmentError:
+ except EnvironmentError as e:
+ logger.debug(f"EnvironmentError: {e}")
blocks = None
cache_dir = kwargs.pop("cache_dir", None)
@@ -2150,8 +2130,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_load_kwargs[key] = value["default"]
try:
components_to_register[name] = spec.load(**component_load_kwargs)
- except Exception as e:
- logger.warning(f"Failed to create component '{name}': {e}")
+ except Exception:
+ logger.warning(
+ f"\nFailed to create component {name}:\n"
+ f"- Component spec: {spec}\n"
+ f"- load() called with kwargs: {component_load_kwargs}\n\n"
+ f"{traceback.format_exc()}"
+ )
# Register all components at once
self.register_components(**components_to_register)
@@ -2181,12 +2166,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
-
-
- If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
- the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-
+ > [!TIP] > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is.
+ Otherwise, > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
@@ -2521,6 +2502,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
if state is None:
state = PipelineState()
+ else:
+ state = deepcopy(state)
# Make a copy of the input kwargs
passed_kwargs = kwargs.copy()
diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py
index 5db860c788..f7ee1dd309 100644
--- a/src/diffusers/modular_pipelines/node_utils.py
+++ b/src/diffusers/modular_pipelines/node_utils.py
@@ -351,11 +351,7 @@ class ModularNode(ConfigMixin):
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
config_name = "node_config.json"
diff --git a/src/diffusers/modular_pipelines/qwenimage/__init__.py b/src/diffusers/modular_pipelines/qwenimage/__init__.py
index 81cf515730..ae4ec4799f 100644
--- a/src/diffusers/modular_pipelines/qwenimage/__init__.py
+++ b/src/diffusers/modular_pipelines/qwenimage/__init__.py
@@ -29,13 +29,20 @@ else:
"EDIT_AUTO_BLOCKS",
"EDIT_BLOCKS",
"EDIT_INPAINT_BLOCKS",
+ "EDIT_PLUS_AUTO_BLOCKS",
+ "EDIT_PLUS_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
+ "QwenImageEditPlusAutoBlocks",
+ ]
+ _import_structure["modular_pipeline"] = [
+ "QwenImageEditModularPipeline",
+ "QwenImageEditPlusModularPipeline",
+ "QwenImageModularPipeline",
]
- _import_structure["modular_pipeline"] = ["QwenImageEditModularPipeline", "QwenImageModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -54,13 +61,20 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EDIT_AUTO_BLOCKS,
EDIT_BLOCKS,
EDIT_INPAINT_BLOCKS,
+ EDIT_PLUS_AUTO_BLOCKS,
+ EDIT_PLUS_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
TEXT2IMAGE_BLOCKS,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
+ QwenImageEditPlusAutoBlocks,
+ )
+ from .modular_pipeline import (
+ QwenImageEditModularPipeline,
+ QwenImageEditPlusModularPipeline,
+ QwenImageModularPipeline,
)
- from .modular_pipeline import QwenImageEditModularPipeline, QwenImageModularPipeline
else:
import sys
diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index 738a1e5d15..fdec95dc50 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -203,7 +203,6 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
block_state.latents = components.pachifier.pack_latents(block_state.latents)
self.set_block_state(state, block_state)
-
return components, state
@@ -571,15 +570,14 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
@property
def description(self) -> str:
- return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be place after prepare_latents step"
+ return "Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit. Should be placed after prepare_latents step"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
- InputParam(
- name="resized_image", required=True, type_hint=torch.Tensor, description="The resized image input"
- ),
+ InputParam(name="image_height", required=True),
+ InputParam(name="image_width", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
@@ -612,10 +610,6 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
block_state = self.get_block_state(state)
# for edit, image size can be different from the target size (height/width)
- image = (
- block_state.resized_image[0] if isinstance(block_state.resized_image, list) else block_state.resized_image
- )
- image_width, image_height = image.size
block_state.img_shapes = [
[
@@ -624,7 +618,11 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
block_state.height // components.vae_scale_factor // 2,
block_state.width // components.vae_scale_factor // 2,
),
- (1, image_height // components.vae_scale_factor // 2, image_width // components.vae_scale_factor // 2),
+ (
+ 1,
+ block_state.image_height // components.vae_scale_factor // 2,
+ block_state.image_width // components.vae_scale_factor // 2,
+ ),
]
] * block_state.batch_size
diff --git a/src/diffusers/modular_pipelines/qwenimage/denoise.py b/src/diffusers/modular_pipelines/qwenimage/denoise.py
index d0704ee6e0..49acd2dc02 100644
--- a/src/diffusers/modular_pipelines/qwenimage/denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/denoise.py
@@ -238,19 +238,27 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
- guider_input_fields = {
- "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
- "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
- "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ guider_inputs = {
+ "encoder_hidden_states": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "encoder_hidden_states_mask": (
+ getattr(block_state, "prompt_embeds_mask", None),
+ getattr(block_state, "negative_prompt_embeds_mask", None),
+ ),
+ "txt_seq_lens": (
+ getattr(block_state, "txt_seq_lens", None),
+ getattr(block_state, "negative_txt_seq_lens", None),
+ ),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
- guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+ guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
- cond_kwargs = guider_state_batch.as_dict()
- cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
@@ -328,19 +336,27 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
- guider_input_fields = {
- "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
- "encoder_hidden_states_mask": ("prompt_embeds_mask", "negative_prompt_embeds_mask"),
- "txt_seq_lens": ("txt_seq_lens", "negative_txt_seq_lens"),
+ guider_inputs = {
+ "encoder_hidden_states": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "encoder_hidden_states_mask": (
+ getattr(block_state, "prompt_embeds_mask", None),
+ getattr(block_state, "negative_prompt_embeds_mask", None),
+ ),
+ "txt_seq_lens": (
+ getattr(block_state, "txt_seq_lens", None),
+ getattr(block_state, "negative_txt_seq_lens", None),
+ ),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
- guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+ guider_state = components.guider.prepare_inputs(guider_inputs)
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
- cond_kwargs = guider_state_batch.as_dict()
- cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
# YiYi TODO: add cache context
guider_state_batch.noise_pred = components.transformer(
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 280fa6a152..04fb3fdc94 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -128,6 +128,61 @@ def get_qwen_prompt_embeds_edit(
return prompt_embeds, encoder_attention_mask
+def get_qwen_prompt_embeds_edit_plus(
+ text_encoder,
+ processor,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[Union[torch.Tensor, List[PIL.Image.Image], PIL.Image.Image]] = None,
+ prompt_template_encode: str = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ img_template_encode: str = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ prompt_template_encode_start_idx: int = 64,
+ device: Optional[torch.device] = None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_template_encode.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_template_encode.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = prompt_template_encode
+
+ drop_idx = prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+ outputs = text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = _extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(device=device)
+ return prompt_embeds, encoder_attention_mask
+
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
@@ -266,6 +321,83 @@ class QwenImageEditResizeDynamicStep(ModularPipelineBlocks):
return components, state
+class QwenImageEditPlusResizeDynamicStep(QwenImageEditResizeDynamicStep):
+ model_name = "qwenimage"
+
+ def __init__(
+ self,
+ input_name: str = "image",
+ output_name: str = "resized_image",
+ vae_image_output_name: str = "vae_image",
+ ):
+ """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+
+ This block resizes an input image or a list input images and exposes the resized result under configurable
+ input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
+ "image", "control_image")
+
+ Args:
+ input_name (str, optional): Name of the image field to read from the
+ pipeline state. Defaults to "image".
+ output_name (str, optional): Name of the resized image field to write
+ back to the pipeline state. Defaults to "resized_image".
+ vae_image_output_name (str, optional): Name of the image field
+ to write back to the pipeline state. This is used by the VAE encoder step later on. QwenImage Edit Plus
+ processes the input image(s) differently for the VL and the VAE.
+ """
+ if not isinstance(input_name, str) or not isinstance(output_name, str):
+ raise ValueError(
+ f"input_name and output_name must be strings but are {type(input_name)} and {type(output_name)}"
+ )
+ self.condition_image_size = 384 * 384
+ self._image_input_name = input_name
+ self._resized_image_output_name = output_name
+ self._vae_image_output_name = vae_image_output_name
+ super().__init__()
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return super().intermediate_outputs + [
+ OutputParam(
+ name=self._vae_image_output_name,
+ type_hint=List[PIL.Image.Image],
+ description="The images to be processed which will be further used by the VAE encoder.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ images = getattr(block_state, self._image_input_name)
+
+ if not is_valid_image_imagelist(images):
+ raise ValueError(f"Images must be image or list of images but are {type(images)}")
+
+ if (
+ not isinstance(images, torch.Tensor)
+ and isinstance(images, PIL.Image.Image)
+ and not isinstance(images, list)
+ ):
+ images = [images]
+
+ # TODO (sayakpaul): revisit this when the inputs are `torch.Tensor`s
+ condition_images = []
+ vae_images = []
+ for img in images:
+ image_width, image_height = img.size
+ condition_width, condition_height, _ = calculate_dimensions(
+ self.condition_image_size, image_width / image_height
+ )
+ condition_images.append(components.image_resize_processor.resize(img, condition_height, condition_width))
+ vae_images.append(img)
+
+ setattr(block_state, self._resized_image_output_name, condition_images)
+ setattr(block_state, self._vae_image_output_name, vae_images)
+ self.set_block_state(state, block_state)
+ return components, state
+
+
class QwenImageTextEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -496,7 +628,7 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
)
if components.requires_unconditional_embeds:
- negative_prompt = block_state.negative_prompt or ""
+ negative_prompt = block_state.negative_prompt or " "
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
components.text_encoder,
components.processor,
@@ -511,6 +643,61 @@ class QwenImageEditTextEncoderStep(ModularPipelineBlocks):
return components, state
+class QwenImageEditPlusTextEncoderStep(QwenImageEditTextEncoderStep):
+ model_name = "qwenimage"
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec(
+ name="prompt_template_encode",
+ default="<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
+ ),
+ ConfigSpec(
+ name="img_template_encode",
+ default="Picture {}: <|vision_start|><|image_pad|><|vision_end|>",
+ ),
+ ConfigSpec(name="prompt_template_encode_start_idx", default=64),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ self.check_inputs(block_state.prompt, block_state.negative_prompt)
+
+ device = components._execution_device
+
+ block_state.prompt_embeds, block_state.prompt_embeds_mask = get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=block_state.prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+
+ if components.requires_unconditional_embeds:
+ negative_prompt = block_state.negative_prompt or " "
+ block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
+ get_qwen_prompt_embeds_edit_plus(
+ components.text_encoder,
+ components.processor,
+ prompt=negative_prompt,
+ image=block_state.resized_image,
+ prompt_template_encode=components.config.prompt_template_encode,
+ img_template_encode=components.config.img_template_encode,
+ prompt_template_encode_start_idx=components.config.prompt_template_encode_start_idx,
+ device=device,
+ )
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -612,12 +799,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
@property
def inputs(self) -> List[InputParam]:
- return [
- InputParam("resized_image"),
- InputParam("image"),
- InputParam("height"),
- InputParam("width"),
- ]
+ return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")]
@property
def intermediate_outputs(self) -> List[OutputParam]:
@@ -661,6 +843,47 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
return components, state
+class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
+ model_name = "qwenimage-edit-plus"
+ vae_image_size = 1024 * 1024
+
+ @property
+ def description(self) -> str:
+ return "Image Preprocess step for QwenImage Edit Plus. Unlike QwenImage Edit, QwenImage Edit Plus doesn't use the same resized image for further preprocessing."
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
+
+ @torch.no_grad()
+ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
+ block_state = self.get_block_state(state)
+
+ if block_state.vae_image is None and block_state.image is None:
+ raise ValueError("`vae_image` and `image` cannot be None at the same time")
+
+ if block_state.vae_image is None:
+ image = block_state.image
+ self.check_inputs(
+ height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor
+ )
+ height = block_state.height or components.default_height
+ width = block_state.width or components.default_width
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+ else:
+ width, height = block_state.vae_image[0].size
+ image = block_state.vae_image
+
+ block_state.processed_image = components.image_processor.preprocess(
+ image=image, height=height, width=width
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -738,7 +961,6 @@ class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
dtype=dtype,
latent_channels=components.num_channels_latents,
)
-
setattr(block_state, self._image_latents_output_name, image_latents)
self.set_block_state(state, block_state)
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
index 2b787c8238..2b229c040b 100644
--- a/src/diffusers/modular_pipelines/qwenimage/inputs.py
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -307,6 +307,13 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
return inputs
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
+ OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
+ ]
+
@property
def expected_components(self) -> List[ComponentSpec]:
return [
@@ -327,6 +334,11 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
block_state.height = block_state.height or height
block_state.width = block_state.width or width
+ if not hasattr(block_state, "image_height"):
+ block_state.image_height = height
+ if not hasattr(block_state, "image_width"):
+ block_state.image_width = width
+
# 2. Patchify the image latent tensor
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index a01c742fcf..83bfcb3da4 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -37,6 +37,9 @@ from .denoise import (
)
from .encoders import (
QwenImageControlNetVaeEncoderStep,
+ QwenImageEditPlusProcessImagesInputStep,
+ QwenImageEditPlusResizeDynamicStep,
+ QwenImageEditPlusTextEncoderStep,
QwenImageEditResizeDynamicStep,
QwenImageEditTextEncoderStep,
QwenImageInpaintProcessImagesInputStep,
@@ -511,17 +514,42 @@ class QwenImageAutoDecodeStep(AutoPipelineBlocks):
)
+class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = [
+ QwenImageAutoInputStep,
+ QwenImageOptionalControlNetInputStep,
+ QwenImageAutoBeforeDenoiseStep,
+ QwenImageOptionalControlNetBeforeDenoiseStep,
+ QwenImageAutoDenoiseStep,
+ ]
+ block_names = ["input", "controlnet_input", "before_denoise", "controlnet_before_denoise", "denoise", "decode"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `QwenImageAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageOptionalControlNetBeforeDenoiseStep` (controlnet_before_denoise) prepares the controlnet input for the denoising step.\n"
+ + " - `QwenImageAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ + " - `QwenImageAutoDecodeStep` (decode) decodes the latents into images.\n\n"
+ + "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ + " - for image-to-image generation, you need to provide `image_latents`\n"
+ + " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ + " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ + " - for text-to-image generation, all you need to provide is prompt embeddings"
+ )
+
+
## 1.10 QwenImage/auto block & presets
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
- ("input", QwenImageAutoInputStep()),
- ("controlnet_input", QwenImageOptionalControlNetInputStep()),
- ("before_denoise", QwenImageAutoBeforeDenoiseStep()),
- ("controlnet_before_denoise", QwenImageOptionalControlNetBeforeDenoiseStep()),
- ("denoise", QwenImageAutoDenoiseStep()),
+ ("denoise", QwenImageCoreDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
@@ -699,7 +727,7 @@ class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
class QwenImageEditAutoInputStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintInputStep, QwenImageEditInputStep]
block_names = ["edit_inpaint", "edit"]
- block_trigger_inputs = ["processed_mask_image", "image"]
+ block_trigger_inputs = ["processed_mask_image", "image_latents"]
@property
def description(self):
@@ -800,13 +828,34 @@ class QwenImageEditAutoDenoiseStep(AutoPipelineBlocks):
## 2.7 QwenImage-Edit/auto blocks & presets
+
+class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support edit (img2img) and edit inpainting workflow for QwenImage Edit:\n"
+ + " - When `processed_mask_image` is provided, it will be used for edit inpainting task.\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
EDIT_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
- ("input", QwenImageEditAutoInputStep()),
- ("before_denoise", QwenImageEditAutoBeforeDenoiseStep()),
- ("denoise", QwenImageEditAutoDenoiseStep()),
+ ("denoise", QwenImageEditCoreDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
@@ -826,7 +875,151 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
)
-# 3. all block presets supported in QwenImage & QwenImage-Edit
+#################### QwenImage Edit Plus #####################
+
+# 3. QwenImage-Edit Plus
+
+## 3.1 QwenImage-Edit Plus / edit
+
+#### QwenImage-Edit Plus vl encoder: take both image and text prompts
+QwenImageEditPlusVLEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()),
+ ("encode", QwenImageEditPlusTextEncoderStep()),
+ ]
+)
+
+
+class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVLEncoderBlocks.values()
+ block_names = QwenImageEditPlusVLEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "QwenImage-Edit Plus VL encoder step that encode the image an text prompts together."
+
+
+#### QwenImage-Edit Plus vae encoder
+QwenImageEditPlusVaeEncoderBlocks = InsertableDict(
+ [
+ ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step
+ ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image
+ ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents
+ ]
+)
+
+
+class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
+ model_name = "qwenimage"
+ block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
+ block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that encode the image inputs into their latent representations."
+
+
+#### QwenImage Edit Plus presets
+EDIT_PLUS_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
+ ("input", QwenImageEditInputStep()),
+ ("prepare_latents", QwenImagePrepareLatentsStep()),
+ ("set_timesteps", QwenImageSetTimestepsStep()),
+ ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+ ("denoise", QwenImageEditDenoiseStep()),
+ ("decode", QwenImageDecodeStep()),
+ ]
+)
+
+
+# auto before_denoise step for edit tasks
+class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [QwenImageEditBeforeDenoiseStep]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
+ + "This is an auto pipeline block that works for edit (img2img) task.\n"
+ + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+ + " - if `image_latents` is not provided, step will be skipped."
+ )
+
+
+## 3.2 QwenImage-Edit Plus/auto encoders
+
+
+class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [
+ QwenImageEditPlusVaeEncoderStep,
+ ]
+ block_names = ["edit"]
+ block_trigger_inputs = ["image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations. \n"
+ " This is an auto pipeline block that works for edit task.\n"
+ + " - `QwenImageEditPlusVaeEncoderStep` (edit) is used when `image` is provided.\n"
+ + " - if `image` is not provided, step will be skipped."
+ )
+
+
+## 3.3 QwenImage-Edit/auto blocks & presets
+
+
+class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = [
+ QwenImageEditAutoInputStep,
+ QwenImageEditPlusAutoBeforeDenoiseStep,
+ QwenImageEditAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `QwenImageEditAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `QwenImageEditPlusAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `QwenImageEditAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support edit (img2img) workflow for QwenImage Edit Plus:\n"
+ + " - When `image_latents` is provided, it will be used for edit (img2img) task.\n"
+ )
+
+
+EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", QwenImageEditPlusVLEncoderStep()),
+ ("vae_encoder", QwenImageEditPlusAutoVaeEncoderStep()),
+ ("denoise", QwenImageEditPlusCoreDenoiseStep()),
+ ("decode", QwenImageAutoDecodeStep()),
+ ]
+)
+
+
+class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
+ model_name = "qwenimage-edit-plus"
+ block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
+ block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for edit (img2img) and edit tasks using QwenImage-Edit Plus.\n"
+ + "- for edit (img2img) generation, you need to provide `image`\n"
+ )
+
+
+# 3. all block presets supported in QwenImage, QwenImage-Edit, QwenImage-Edit Plus
ALL_BLOCKS = {
@@ -834,8 +1027,10 @@ ALL_BLOCKS = {
"img2img": IMAGE2IMAGE_BLOCKS,
"edit": EDIT_BLOCKS,
"edit_inpaint": EDIT_INPAINT_BLOCKS,
+ "edit_plus": EDIT_PLUS_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"auto": AUTO_BLOCKS,
"edit_auto": EDIT_AUTO_BLOCKS,
+ "edit_plus_auto": EDIT_PLUS_AUTO_BLOCKS,
}
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
index fe9757f41b..d9e30864f6 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_pipeline.py
@@ -97,13 +97,11 @@ class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "QwenImageAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
@@ -151,13 +149,11 @@ class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage-Edit.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "QwenImageEditAutoBlocks"
+
# YiYi TODO: qwen edit should not provide default height/width, should be derived from the resized input image (after adjustment) produced by the resize step.
@property
def default_height(self):
@@ -200,3 +196,13 @@ class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
+
+
+class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
+ """
+ A ModularPipeline for QwenImage-Edit Plus.
+
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
+ """
+
+ default_blocks_name = "QwenImageEditPlusAutoBlocks"
diff --git a/src/diffusers/modular_pipelines/qwenimage/node_utils.py b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
new file mode 100644
index 0000000000..3230ece68a
--- /dev/null
+++ b/src/diffusers/modular_pipelines/qwenimage/node_utils.py
@@ -0,0 +1,95 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# mellon nodes
+QwenImage_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ "vae",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": ["controlnet_vae_encoder"],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ "controlnet",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ },
+}
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
index fefa622f1a..70cbf0c1c7 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -262,37 +262,37 @@ class StableDiffusionXLInputStep(ModularPipelineBlocks):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="negative pooled text embeddings used to guide the image generation",
),
OutputParam(
"ip_adapter_embeds",
type_hint=List[torch.Tensor],
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="image embeddings for IP-Adapter",
),
OutputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="negative image embeddings for IP-Adapter",
),
]
@@ -1120,13 +1120,13 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineB
OutputParam(
"add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The time ids to condition the denoising process",
),
OutputParam(
"negative_add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The negative time ids to condition the denoising process",
),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
@@ -1331,13 +1331,13 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
OutputParam(
"add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The time ids to condition the denoising process",
),
OutputParam(
"negative_add_time_ids",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="The negative time ids to condition the denoising process",
),
OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
index a2e1420595..862315e591 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -115,7 +115,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
def check_inputs(components, block_state):
num_channels_unet = components.num_channels_unet
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
if block_state.mask is None or block_state.masked_image_latents is None:
raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
num_channels_latents = block_state.latents.shape[1]
@@ -183,14 +183,14 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
- "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@@ -201,27 +201,41 @@ class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
- guider_input_fields = {
- "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
- "time_ids": ("add_time_ids", "negative_add_time_ids"),
- "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
- "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+ guider_inputs = {
+ "prompt_embeds": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "time_ids": (
+ getattr(block_state, "add_time_ids", None),
+ getattr(block_state, "negative_add_time_ids", None),
+ ),
+ "text_embeds": (
+ getattr(block_state, "pooled_prompt_embeds", None),
+ getattr(block_state, "negative_pooled_prompt_embeds", None),
+ ),
+ "image_embeds": (
+ getattr(block_state, "ip_adapter_embeds", None),
+ getattr(block_state, "negative_ip_adapter_embeds", None),
+ ),
}
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
- # Prepare mini‐batches according to guidance method and `guider_input_fields`
- # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
- # e.g. for CFG, we prepare two batches: one for uncond, one for cond
- # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
- # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
- guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.unet)
- cond_kwargs = guider_state_batch.as_dict()
- cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
@@ -307,14 +321,14 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds, "
"add_time_ids/negative_add_time_ids, "
"pooled_prompt_embeds/negative_pooled_prompt_embeds, "
"and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
- "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ "please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
InputParam(
@@ -344,11 +358,23 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
- guider_input_fields = {
- "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
- "time_ids": ("add_time_ids", "negative_add_time_ids"),
- "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
- "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+ guider_inputs = {
+ "prompt_embeds": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
+ "time_ids": (
+ getattr(block_state, "add_time_ids", None),
+ getattr(block_state, "negative_add_time_ids", None),
+ ),
+ "text_embeds": (
+ getattr(block_state, "pooled_prompt_embeds", None),
+ getattr(block_state, "negative_pooled_prompt_embeds", None),
+ ),
+ "image_embeds": (
+ getattr(block_state, "ip_adapter_embeds", None),
+ getattr(block_state, "negative_ip_adapter_embeds", None),
+ ),
}
# cond_scale for the timestep (controlnet input)
@@ -369,12 +395,15 @@ class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
# guided denoiser step
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
- # Prepare mini‐batches according to guidance method and `guider_input_fields`
- # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
- # e.g. for CFG, we prepare two batches: one for uncond, one for cond
- # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
- # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
- guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
index 1e8921d363..90b254b6f5 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -258,25 +258,25 @@ class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
OutputParam(
"pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="pooled text embeddings used to guide the image generation",
),
OutputParam(
"negative_pooled_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative pooled text embeddings used to guide the image generation",
),
]
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
index c9033856bc..68b5e33755 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -82,19 +82,17 @@ class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
# before_denoise: text2img
class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -104,19 +102,17 @@ class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
# before_denoise: img2img
class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -126,19 +122,17 @@ class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
# before_denoise: inpainting
class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = [
- StableDiffusionXLInputStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
]
- block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+ block_names = ["set_timesteps", "prepare_latents", "prepare_add_cond"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ "This is a sequential pipeline blocks:\n"
- + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
@@ -255,25 +249,48 @@ class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
)
+class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLAutoBeforeDenoiseStep,
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoDenoiseStep,
+ ]
+ block_names = ["input", "before_denoise", "controlnet_input", "denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Core step that performs the denoising process. \n"
+ + " - `StableDiffusionXLInputStep` (input) standardizes the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ + " - `StableDiffusionXLAutoControlNetInputStep` (controlnet_input) prepares the controlnet input.\n"
+ + " - `StableDiffusionXLAutoDenoiseStep` (denoise) iteratively denoises the latents.\n\n"
+ + "This step support text-to-image, image-to-image, inpainting, with or without controlnet/controlnet_union/ip_adapter for Stable Diffusion XL:\n"
+ + "- for image-to-image generation, you need to provide `image_latents`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image_latents`\n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ + "- to run the ip_adapter workflow, you need to load ip_adapter into your unet and provide `ip_adapter_embeds`\n"
+ + "- for text-to-image generation, all you need to provide is prompt embeddings\n"
+ )
+
+
# ip-adapter, controlnet, text2img, img2img, inpainting
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
- StableDiffusionXLAutoBeforeDenoiseStep,
- StableDiffusionXLAutoControlNetInputStep,
- StableDiffusionXLAutoDenoiseStep,
+ StableDiffusionXLCoreDenoiseStep,
StableDiffusionXLAutoDecodeStep,
]
block_names = [
"text_encoder",
"ip_adapter",
- "image_encoder",
- "before_denoise",
- "controlnet_input",
+ "vae_encoder",
"denoise",
- "decoder",
+ "decode",
]
@property
@@ -321,7 +338,7 @@ TEXT2IMAGE_BLOCKS = InsertableDict(
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
- ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("vae_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
@@ -334,7 +351,7 @@ IMAGE2IMAGE_BLOCKS = InsertableDict(
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
- ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
+ ("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
@@ -361,10 +378,8 @@ AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
- ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
- ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
- ("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
- ("denoise", StableDiffusionXLAutoDenoiseStep),
+ ("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
+ ("denoise", StableDiffusionXLCoreDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep),
]
)
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
index e84f5cad1a..f2a4c96073 100644
--- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -47,13 +47,11 @@ class StableDiffusionXLModularPipeline(
"""
A ModularPipeline for Stable Diffusion XL.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "StableDiffusionXLAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_size * self.vae_scale_factor
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
new file mode 100644
index 0000000000..3e788bf947
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/node_utils.py
@@ -0,0 +1,99 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+SDXL_NODE_TYPES_PARAMS_MAP = {
+ "controlnet": {
+ "inputs": [
+ "control_image",
+ "controlnet_conditioning_scale",
+ "control_guidance_start",
+ "control_guidance_end",
+ "height",
+ "width",
+ ],
+ "model_inputs": [
+ "controlnet",
+ ],
+ "outputs": [
+ "controlnet_out",
+ ],
+ "block_names": [None],
+ },
+ "denoise": {
+ "inputs": [
+ "embeddings",
+ "width",
+ "height",
+ "seed",
+ "num_inference_steps",
+ "guidance_scale",
+ "image_latents",
+ "strength",
+ # custom adapters coming in as inputs
+ "controlnet",
+ # ip_adapter is optional and custom; include if available
+ "ip_adapter",
+ ],
+ "model_inputs": [
+ "unet",
+ "guider",
+ "scheduler",
+ ],
+ "outputs": [
+ "latents",
+ "latents_preview",
+ ],
+ "block_names": ["denoise"],
+ },
+ "vae_encoder": {
+ "inputs": [
+ "image",
+ "width",
+ "height",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "image_latents",
+ ],
+ "block_names": ["vae_encoder"],
+ },
+ "text_encoder": {
+ "inputs": [
+ "prompt",
+ "negative_prompt",
+ ],
+ "model_inputs": [
+ "text_encoders",
+ ],
+ "outputs": [
+ "embeddings",
+ ],
+ "block_names": ["text_encoder"],
+ },
+ "decoder": {
+ "inputs": [
+ "latents",
+ ],
+ "model_inputs": [
+ "vae",
+ ],
+ "outputs": [
+ "images",
+ ],
+ "block_names": ["decode"],
+ },
+}
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
index 2b9889f877..d48f678edd 100644
--- a/src/diffusers/modular_pipelines/wan/before_denoise.py
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -146,13 +146,13 @@ class WanInputStep(ModularPipelineBlocks):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ kwargs_type="denoiser_input_fields", # already in intermedites state but declare here again for denoiser_input_fields
description="negative text embeddings used to guide the image generation",
),
]
diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
index 5f578609c2..4f3ca80acc 100644
--- a/src/diffusers/modular_pipelines/wan/denoise.py
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -79,11 +79,11 @@ class WanLoopDenoiser(ModularPipelineBlocks):
description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
),
InputParam(
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description=(
"All conditional model inputs that need to be prepared with guider. "
"It should contain prompt_embeds/negative_prompt_embeds. "
- "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ "Please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
),
),
]
@@ -94,25 +94,30 @@ class WanLoopDenoiser(ModularPipelineBlocks):
) -> PipelineState:
# Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
# to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
- guider_input_fields = {
- "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+ guider_inputs = {
+ "prompt_embeds": (
+ getattr(block_state, "prompt_embeds", None),
+ getattr(block_state, "negative_prompt_embeds", None),
+ ),
}
transformer_dtype = components.transformer.dtype
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
- # Prepare mini‐batches according to guidance method and `guider_input_fields`
- # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
- # e.g. for CFG, we prepare two batches: one for uncond, one for cond
- # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
- # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
- guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = components.guider.prepare_inputs(guider_inputs)
# run the denoiser for each guidance batch
for guider_state_batch in guider_state:
components.guider.prepare_models(components.transformer)
- cond_kwargs = guider_state_batch.as_dict()
- cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}
prompt_embeds = cond_kwargs.pop("prompt_embeds")
# Predict the noise residual
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
index a0bf76b99b..cb2fc24238 100644
--- a/src/diffusers/modular_pipelines/wan/encoders.py
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -89,13 +89,13 @@ class WanTextEncoderStep(ModularPipelineBlocks):
OutputParam(
"prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="text embeddings used to guide the image generation",
),
OutputParam(
"negative_prompt_embeds",
type_hint=torch.Tensor,
- kwargs_type="guider_input_fields",
+ kwargs_type="denoiser_input_fields",
description="negative text embeddings used to guide the image generation",
),
]
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
index 4d86e0d08e..e4adf3d151 100644
--- a/src/diffusers/modular_pipelines/wan/modular_pipeline.py
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -30,13 +30,11 @@ class WanModularPipeline(
"""
A ModularPipeline for Wan.
-
-
- This is an experimental feature and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
+ default_blocks_name = "WanAutoBlocks"
+
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index 363caffe20..6f9ab7b291 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -159,7 +159,7 @@ init_image = download_image(img_url).resize((512, 512))
mask_image = download_image(mask_url).resize((512, 512))
pipe = StableDiffusionInpaintPipeline.from_pretrained(
- "runwayml/stable-diffusion-inpainting",
+ "stable-diffusion-v1-5/stable-diffusion-inpainting",
torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 8ed07a72e3..87d953845e 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -128,6 +128,7 @@ else:
"AnimateDiffVideoToVideoControlNetPipeline",
]
_import_structure["bria"] = ["BriaPipeline"]
+ _import_structure["bria_fibo"] = ["BriaFiboPipeline"]
_import_structure["flux"] = [
"FluxControlPipeline",
"FluxControlInpaintPipeline",
@@ -144,6 +145,7 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
+ _import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -240,6 +242,7 @@ else:
"HunyuanVideoImageToVideoPipeline",
"HunyuanVideoFramepackPipeline",
]
+ _import_structure["hunyuan_image"] = ["HunyuanImagePipeline", "HunyuanImageRefinerPipeline"]
_import_structure["kandinsky"] = [
"KandinskyCombinedPipeline",
"KandinskyImg2ImgCombinedPipeline",
@@ -285,6 +288,7 @@ else:
]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
+ _import_structure["lucy"] = ["LucyEditPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -304,6 +308,7 @@ else:
"SanaSprintPipeline",
"SanaControlNetPipeline",
"SanaSprintImg2ImgPipeline",
+ "SanaVideoPipeline",
]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
@@ -381,6 +386,7 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
+ _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"]
_import_structure["skyreels_v2"] = [
"SkyReelsV2DiffusionForcingPipeline",
"SkyReelsV2DiffusionForcingImageToVideoPipeline",
@@ -393,6 +399,7 @@ else:
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImageEditPipeline",
+ "QwenImageEditPlusPipeline",
"QwenImageEditInpaintPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
@@ -557,6 +564,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
+ from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,
@@ -636,6 +644,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ReduxImageEncoder,
)
from .hidream_image import HiDreamImagePipeline
+ from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
HunyuanSkyreelsImageToVideoPipeline,
HunyuanVideoFramepackPipeline,
@@ -669,6 +678,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3Img2ImgPipeline,
Kandinsky3Pipeline,
)
+ from .kandinsky5 import Kandinsky5T2VPipeline
from .latent_consistency_models import (
LatentConsistencyModelImg2ImgPipeline,
LatentConsistencyModelPipeline,
@@ -682,6 +692,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
+ from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
from .marigold import (
@@ -714,16 +725,24 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+ from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
)
- from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
+ from .sana import (
+ SanaControlNetPipeline,
+ SanaPipeline,
+ SanaSprintImg2ImgPipeline,
+ SanaSprintPipeline,
+ SanaVideoPipeline,
+ )
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_audio import StableAudioPipeline, StableAudioProjectionModel
diff --git a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
index 546ae9239a..b6b40cd6e6 100644
--- a/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
+++ b/src/diffusers/pipelines/audioldm2/modeling_audioldm2.py
@@ -17,7 +17,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 880984eeb8..8a32d4c367 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -95,6 +95,7 @@ from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
QwenImageEditPipeline,
+ QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImagePipeline,
@@ -186,6 +187,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux-kontext", FluxKontextPipeline),
("qwenimage", QwenImageImg2ImgPipeline),
("qwenimage-edit", QwenImageEditPipeline),
+ ("qwenimage-edit-plus", QwenImageEditPlusPipeline),
]
)
@@ -407,12 +409,8 @@ class AutoPipelineForText2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -702,12 +700,8 @@ class AutoPipelineForImage2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -1012,12 +1006,8 @@ class AutoPipelineForInpainting(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
index 928698e442..b061ac2636 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_blip2.py
@@ -14,7 +14,6 @@
from typing import Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from transformers import BertTokenizer
from transformers.activations import QuickGELUActivation as QuickGELU
diff --git a/src/diffusers/pipelines/bria_fibo/__init__.py b/src/diffusers/pipelines/bria_fibo/__init__.py
new file mode 100644
index 0000000000..206a463b39
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_bria_fibo"] = ["BriaFiboPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_bria_fibo import BriaFiboPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
new file mode 100644
index 0000000000..85d29029e6
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -0,0 +1,838 @@
+# Copyright (c) Bria.ai. All rights reserved.
+#
+# This file is licensed under the Creative Commons Attribution-NonCommercial 4.0 International Public License (CC-BY-NC-4.0).
+# You may obtain a copy of the license at https://creativecommons.org/licenses/by-nc/4.0/
+#
+# You are free to share and adapt this material for non-commercial purposes provided you give appropriate credit,
+# indicate if changes were made, and do not use the material for commercial purposes.
+#
+# See the license for further details.
+
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3ForCausalLM
+
+from ...image_processor import VaeImageProcessor
+from ...loaders import FluxLoraLoaderMixin
+from ...models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan
+from ...models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from ...pipelines.bria_fibo.pipeline_output import BriaFiboPipelineOutput
+from ...pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
+from ...pipelines.pipeline_utils import DiffusionPipeline
+from ...schedulers import FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Example:
+ ```python
+ import torch
+ from diffusers import BriaFiboPipeline
+ from diffusers.modular_pipelines import ModularPipeline
+
+ torch.set_grad_enabled(False)
+ vlm_pipe = ModularPipeline.from_pretrained("briaai/FIBO-VLM-prompt-to-JSON", trust_remote_code=True)
+
+ pipe = BriaFiboPipeline.from_pretrained(
+ "briaai/FIBO",
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ )
+ pipe.enable_model_cpu_offload()
+
+ with torch.inference_mode():
+ # 1. Create a prompt to generate an initial image
+ output = vlm_pipe(prompt="a beautiful dog")
+ json_prompt_generate = output.values["json_prompt"]
+
+ # Generate the image from the structured json prompt
+ results_generate = pipe(prompt=json_prompt_generate, num_inference_steps=50, guidance_scale=5)
+ results_generate.images[0].save("image_generate.png")
+ ```
+"""
+
+
+class BriaFiboPipeline(DiffusionPipeline):
+ r"""
+ Args:
+ transformer (`BriaFiboTransformer2DModel`):
+ The transformer model for 2D diffusion modeling.
+ scheduler (`FlowMatchEulerDiscreteScheduler` or `KarrasDiffusionSchedulers`):
+ Scheduler to be used with `transformer` to denoise the encoded latents.
+ vae (`AutoencoderKLWan`):
+ Variational Auto-Encoder for encoding and decoding images to and from latent representations.
+ text_encoder (`SmolLM3ForCausalLM`):
+ Text encoder for processing input prompts.
+ tokenizer (`AutoTokenizer`):
+ Tokenizer used for processing the input text prompts for the text_encoder.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ transformer: BriaFiboTransformer2DModel,
+ scheduler: Union[FlowMatchEulerDiscreteScheduler, KarrasDiffusionSchedulers],
+ vae: AutoencoderKLWan,
+ text_encoder: SmolLM3ForCausalLM,
+ tokenizer: AutoTokenizer,
+ ):
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor = 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.default_sample_size = 64
+
+ def get_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 2048,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if not prompt:
+ raise ValueError("`prompt` must be a non-empty string or list of strings.")
+
+ batch_size = len(prompt)
+ bot_token_id = 128000
+
+ text_encoder_device = device if device is not None else torch.device("cpu")
+ if not isinstance(text_encoder_device, torch.device):
+ text_encoder_device = torch.device(text_encoder_device)
+
+ if all(p == "" for p in prompt):
+ input_ids = torch.full((batch_size, 1), bot_token_id, dtype=torch.long, device=text_encoder_device)
+ attention_mask = torch.ones_like(input_ids)
+ else:
+ tokenized = self.tokenizer(
+ prompt,
+ padding="longest",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ input_ids = tokenized.input_ids.to(text_encoder_device)
+ attention_mask = tokenized.attention_mask.to(text_encoder_device)
+
+ if any(p == "" for p in prompt):
+ empty_rows = torch.tensor([p == "" for p in prompt], dtype=torch.bool, device=text_encoder_device)
+ input_ids[empty_rows] = bot_token_id
+ attention_mask[empty_rows] = 1
+
+ encoder_outputs = self.text_encoder(
+ input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )
+ hidden_states = encoder_outputs.hidden_states
+
+ prompt_embeds = torch.cat([hidden_states[-1], hidden_states[-2]], dim=-1)
+ prompt_embeds = prompt_embeds.to(device=device, dtype=dtype)
+
+ prompt_embeds = prompt_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ hidden_states = tuple(
+ layer.repeat_interleave(num_images_per_prompt, dim=0).to(device=device) for layer in hidden_states
+ )
+ attention_mask = attention_mask.repeat_interleave(num_images_per_prompt, dim=0).to(device=device)
+
+ return prompt_embeds, hidden_states, attention_mask
+
+ @staticmethod
+ def pad_embedding(prompt_embeds, max_tokens, attention_mask=None):
+ # Pad embeddings to `max_tokens` while preserving the mask of real tokens.
+ batch_size, seq_len, dim = prompt_embeds.shape
+
+ if attention_mask is None:
+ attention_mask = torch.ones((batch_size, seq_len), dtype=prompt_embeds.dtype, device=prompt_embeds.device)
+ else:
+ attention_mask = attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if max_tokens < seq_len:
+ raise ValueError("`max_tokens` must be greater or equal to the current sequence length.")
+
+ if max_tokens > seq_len:
+ pad_length = max_tokens - seq_len
+ padding = torch.zeros(
+ (batch_size, pad_length, dim), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ prompt_embeds = torch.cat([prompt_embeds, padding], dim=1)
+
+ mask_padding = torch.zeros(
+ (batch_size, pad_length), dtype=prompt_embeds.dtype, device=prompt_embeds.device
+ )
+ attention_mask = torch.cat([attention_mask, mask_padding], dim=1)
+
+ return prompt_embeds, attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 3000,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ guidance_scale (`float`):
+ Guidance scale for classifier free guidance.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ prompt_attention_mask = None
+ negative_prompt_attention_mask = None
+ if prompt_embeds is None:
+ prompt_embeds, prompt_layers, prompt_attention_mask = self.get_prompt_embeds(
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ prompt_embeds = prompt_embeds.to(dtype=self.transformer.dtype)
+ prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in prompt_layers]
+
+ if guidance_scale > 1:
+ if isinstance(negative_prompt, list) and negative_prompt[0] is None:
+ negative_prompt = ""
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds, negative_prompt_layers, negative_prompt_attention_mask = self.get_prompt_embeds(
+ prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.transformer.dtype)
+ negative_prompt_layers = [tensor.to(dtype=self.transformer.dtype) for tensor in negative_prompt_layers]
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ # Pad to longest
+ if prompt_attention_mask is not None:
+ prompt_attention_mask = prompt_attention_mask.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
+
+ if negative_prompt_embeds is not None:
+ if negative_prompt_attention_mask is not None:
+ negative_prompt_attention_mask = negative_prompt_attention_mask.to(
+ device=negative_prompt_embeds.device, dtype=negative_prompt_embeds.dtype
+ )
+ max_tokens = max(negative_prompt_embeds.shape[1], prompt_embeds.shape[1])
+
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in prompt_layers]
+
+ negative_prompt_embeds, negative_prompt_attention_mask = self.pad_embedding(
+ negative_prompt_embeds, max_tokens, attention_mask=negative_prompt_attention_mask
+ )
+ negative_prompt_layers = [self.pad_embedding(layer, max_tokens)[0] for layer in negative_prompt_layers]
+ else:
+ max_tokens = prompt_embeds.shape[1]
+ prompt_embeds, prompt_attention_mask = self.pad_embedding(
+ prompt_embeds, max_tokens, attention_mask=prompt_attention_mask
+ )
+ negative_prompt_layers = None
+
+ dtype = self.text_encoder.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[0], max_tokens, 3).to(device=device, dtype=dtype)
+
+ return (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ )
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @staticmethod
+ # Based on diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ def _unpack_latents_no_patch(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ height = height // vae_scale_factor
+ width = width // vae_scale_factor
+
+ latents = latents.view(batch_size, height, width, channels)
+ latents = latents.permute(0, 3, 1, 2)
+
+ return latents
+
+ @staticmethod
+ def _pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.permute(0, 2, 3, 1)
+ latents = latents.reshape(batch_size, height * width, num_channels_latents)
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ do_patching=False,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ if do_patching:
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ else:
+ latents = self._pack_latents_no_patch(latents, batch_size, num_channels_latents, height, width)
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+ return latents, latent_image_ids
+
+ @staticmethod
+ def _prepare_attention_mask(attention_mask):
+ attention_matrix = torch.einsum("bi,bj->bij", attention_mask, attention_mask)
+
+ # convert to 0 - keep, -inf ignore
+ attention_matrix = torch.where(
+ attention_matrix == 1, 0.0, -torch.inf
+ ) # Apply -inf to ignored tokens for nulling softmax score
+ return attention_matrix
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 30,
+ timesteps: List[int] = None,
+ guidance_scale: float = 5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 3000,
+ do_patching=False,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 5.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
+ of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 3000): Maximum sequence length to use with the `prompt`.
+ do_patching (`bool`, *optional*, defaults to `False`): Whether to use patching.
+ Examples:
+ Returns:
+ [`~pipelines.flux.BriaFiboPipelineOutput`] or `tuple`: [`~pipelines.flux.BriaFiboPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ height=height,
+ width=width,
+ prompt_embeds=prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ text_ids,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ prompt_layers,
+ negative_prompt_layers,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ guidance_scale=guidance_scale,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ num_images_per_prompt=num_images_per_prompt,
+ lora_scale=lora_scale,
+ )
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if guidance_scale > 1:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_layers = [
+ torch.cat([negative_prompt_layers[i], prompt_layers[i]], dim=0) for i in range(len(prompt_layers))
+ ]
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ total_num_layers_transformer = len(self.transformer.transformer_blocks) + len(
+ self.transformer.single_transformer_blocks
+ )
+ if len(prompt_layers) >= total_num_layers_transformer:
+ # remove first layers
+ prompt_layers = prompt_layers[len(prompt_layers) - total_num_layers_transformer :]
+ else:
+ # duplicate last layer
+ prompt_layers = prompt_layers + [prompt_layers[-1]] * (total_num_layers_transformer - len(prompt_layers))
+
+ # 5. Prepare latent variables
+
+ num_channels_latents = self.transformer.config.in_channels
+ if do_patching:
+ num_channels_latents = int(num_channels_latents / 4)
+
+ latents, latent_image_ids = self.prepare_latents(
+ prompt_batch_size,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ do_patching,
+ )
+
+ latent_attention_mask = torch.ones(
+ [latents.shape[0], latents.shape[1]], dtype=latents.dtype, device=latents.device
+ )
+ if guidance_scale > 1:
+ latent_attention_mask = latent_attention_mask.repeat(2, 1)
+
+ attention_mask = torch.cat([prompt_attention_mask, latent_attention_mask], dim=1)
+ attention_mask = self._prepare_attention_mask(attention_mask) # batch, seq => batch, seq, seq
+ attention_mask = attention_mask.unsqueeze(dim=1).to(dtype=self.transformer.dtype) # for head broadcasting
+
+ if self._joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+ self._joint_attention_kwargs["attention_mask"] = attention_mask
+
+ # Adapt scheduler to dynamic shifting (resolution dependent)
+
+ if do_patching:
+ seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
+ else:
+ seq_len = (height // self.vae_scale_factor) * (width // self.vae_scale_factor)
+
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+
+ mu = calculate_shift(
+ seq_len,
+ self.scheduler.config.base_image_seq_len,
+ self.scheduler.config.max_image_seq_len,
+ self.scheduler.config.base_shift,
+ self.scheduler.config.max_shift,
+ )
+
+ # Init sigmas and timesteps according to shift size
+ # This changes the scheduler in-place according to the dynamic scheduling
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps=num_inference_steps,
+ device=device,
+ timesteps=None,
+ sigmas=sigmas,
+ mu=mu,
+ )
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # Support old different diffusers versions
+ if len(latent_image_ids.shape) == 3:
+ latent_image_ids = latent_image_ids[0]
+
+ if len(text_ids.shape) == 3:
+ text_ids = text_ids[0]
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0]).to(
+ device=latent_model_input.device, dtype=latent_model_input.dtype
+ )
+
+ # This is predicts "v" from flow-matching or eps from diffusion
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ text_encoder_layers=prompt_layers,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ )[0]
+
+ # perform guidance
+ if guidance_scale > 1:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ image = latents
+
+ else:
+ if do_patching:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ else:
+ latents = self._unpack_latents_no_patch(latents, height, width, self.vae_scale_factor)
+
+ latents = latents.unsqueeze(dim=2)
+ latents_device = latents[0].device
+ latents_dtype = latents[0].dtype
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents_device, latents_dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents_device, latents_dtype
+ )
+ latents_scaled = [latent / latents_std + latents_mean for latent in latents]
+ latents_scaled = torch.cat(latents_scaled, dim=0)
+ image = []
+ for scaled_latent in latents_scaled:
+ curr_image = self.vae.decode(scaled_latent.unsqueeze(0), return_dict=False)[0]
+ curr_image = self.image_processor.postprocess(curr_image.squeeze(dim=2), output_type=output_type)
+ image.append(curr_image)
+ if len(image) == 1:
+ image = image[0]
+ else:
+ image = np.stack(image, axis=0)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return BriaFiboPipelineOutput(images=image)
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 3000:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 3000 but is {max_sequence_length}")
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_output.py b/src/diffusers/pipelines/bria_fibo/pipeline_output.py
new file mode 100644
index 0000000000..f459185a2c
--- /dev/null
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class BriaFiboPipelineOutput(BaseOutput):
+ """
+ Output class for BriaFibo pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index f3ed700bc4..ed6c2c2105 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
- >>> model_id = "lodestones/Chroma"
- >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+ >>> model_id = "lodestones/Chroma1-HD"
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
@@ -158,7 +158,7 @@ class ChromaPipeline(
r"""
The Chroma pipeline for text-to-image generation.
- Reference: https://huggingface.co/lodestones/Chroma/
+ Reference: https://huggingface.co/lodestones/Chroma1-HD/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -233,20 +233,23 @@ class ChromaPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
- attention_mask = text_inputs.attention_mask.clone()
+ tokenizer_mask = text_inputs.attention_mask
- # Chroma requires the attention mask to include one padding token
- seq_lengths = attention_mask.sum(dim=1)
- mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
- attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+ tokenizer_mask_device = tokenizer_mask.to(device)
+ # unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
prompt_embeds = self.text_encoder(
- text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+ text_input_ids.to(device),
+ output_hidden_states=False,
+ attention_mask=tokenizer_mask_device,
)[0]
- dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- attention_mask = attention_mask.to(dtype=dtype, device=device)
+
+ # for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
+ seq_lengths = tokenizer_mask_device.sum(dim=1)
+ mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -605,10 +608,9 @@ class ChromaPipeline(
# Extend the prompt attention mask to account for image tokens in the final sequence
attention_mask = torch.cat(
- [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device)],
+ [attention_mask, torch.ones(batch_size, sequence_length, device=attention_mask.device, dtype=torch.bool)],
dim=1,
)
- attention_mask = attention_mask.to(dtype)
return attention_mask
@@ -688,11 +690,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
- a model to generate images more aligned with `prompt` at the expense of lower image quality.
-
- Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
- the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 26f13fe06c..470c746e41 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
- >>> model_id = "lodestones/Chroma"
- >>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
+ >>> model_id = "lodestones/Chroma1-HD"
+ >>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... model_id,
... transformer=transformer,
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
r"""
The Chroma pipeline for image-to-image generation.
- Reference: https://huggingface.co/lodestones/Chroma/
+ Reference: https://huggingface.co/lodestones/Chroma1-HD/
Args:
transformer ([`ChromaTransformer2DModel`]):
@@ -247,20 +247,21 @@ class ChromaImg2ImgPipeline(
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
- attention_mask = text_inputs.attention_mask.clone()
+ tokenizer_mask = text_inputs.attention_mask
- # Chroma requires the attention mask to include one padding token
- seq_lengths = attention_mask.sum(dim=1)
- mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
- attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
+ tokenizer_mask_device = tokenizer_mask.to(device)
prompt_embeds = self.text_encoder(
- text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
+ text_input_ids.to(device),
+ output_hidden_states=False,
+ attention_mask=tokenizer_mask_device,
)[0]
- dtype = self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
- attention_mask = attention_mask.to(dtype=dtype, device=device)
+
+ seq_lengths = tokenizer_mask_device.sum(dim=1)
+ mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
+ attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
_, seq_len, _ = prompt_embeds.shape
@@ -749,12 +750,12 @@ class ChromaImg2ImgPipeline(
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
- guidance_scale (`float`, *optional*, defaults to 5.0):
- Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
- a model to generate images more aligned with `prompt` at the expense of lower image quality.
-
- Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
- the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
strength (`float, *optional*, defaults to 0.9):
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 41303d9c5c..6de8e5747b 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -146,16 +146,13 @@ class StableDiffusionControlNetInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+ > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting >
([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
- as well as default text-to-image Stable Diffusion checkpoints
+ > as well as default text-to-image Stable Diffusion checkpoints >
([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
- Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
- those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-
+ > Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+ on > those, such as
+ [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
Args:
vae ([`AutoencoderKL`]):
diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
index 1de1d4bde7..d4c6f336df 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py
@@ -394,12 +394,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
Examples:
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index c763411ab5..d605eac1f2 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -266,7 +266,7 @@ class StableDiffusion3ControlNetPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -355,7 +355,7 @@ class StableDiffusion3ControlNetPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index c33cf979c6..9d0158c6b6 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -284,7 +284,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -373,7 +373,7 @@ class StableDiffusion3ControlNetInpaintingPipeline(
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
index 1545027a28..3682ddc911 100644
--- a/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
+++ b/src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py
@@ -133,8 +133,8 @@ class StableDiffusionControlNetXSPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
index 003e748274..6f484aa3e2 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py
@@ -185,8 +185,8 @@ class AltDiffusionPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -266,8 +266,8 @@ class AltDiffusionPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 64bd06d02e..d6bf901207 100644
--- a/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -213,8 +213,8 @@ class AltDiffusionImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -294,8 +294,8 @@ class AltDiffusionImg2ImgPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 59c79e134e..08f8c7e26f 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -162,8 +162,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta
instance of [`DDIMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -226,8 +226,8 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
index 2d9eaa493f..fcd8bf317a 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py
@@ -62,7 +62,8 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
index 205ace65ee..ba0dd66c29 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py
@@ -111,7 +111,8 @@ class StableDiffusionInpaintPipelineLegacy(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
@@ -196,8 +197,8 @@ class StableDiffusionInpaintPipelineLegacy(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
index d81f0d2625..b7a0be57c1 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py
@@ -64,8 +64,8 @@ class StableDiffusionModelEditingPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
with_to_k ([`bool`]):
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
index 2331157ba5..c236e73bf4 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py
@@ -46,10 +46,12 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import DDPMParallelScheduler
>>> from diffusers import StableDiffusionParadigmsPipeline
- >>> scheduler = DDPMParallelScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
+ >>> scheduler = DDPMParallelScheduler.from_pretrained(
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="scheduler"
+ ... )
>>> pipe = StableDiffusionParadigmsPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", scheduler=scheduler, torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
@@ -95,8 +97,8 @@ class StableDiffusionParadigmsPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
index d000d87e6a..2a461ae20c 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py
@@ -303,7 +303,8 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin
[`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`], or [`DDPMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ Please, refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ details.
feature_extractor ([`CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
requires_safety_checker (bool):
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
index eda950998d..397fbc0d85 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py
@@ -1000,11 +1000,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
self.original_attn_processors = None
@@ -1021,11 +1017,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
index 61582853b0..9ff8e98577 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py
@@ -38,8 +38,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
index 2beb0be57b..034a022641 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py
@@ -18,7 +18,6 @@ from typing import Callable, List, Optional, Union
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ....image_processor import VaeImageProcessor
diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
index adfd899e76..2f54f4fc98 100644
--- a/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
+++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py
@@ -16,7 +16,6 @@ import inspect
from typing import Callable, List, Optional, Union
import torch
-import torch.utils.checkpoint
from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTokenizer
from ....image_processor import VaeImageProcessor
diff --git a/src/diffusers/pipelines/hunyuan_image/__init__.py b/src/diffusers/pipelines/hunyuan_image/__init__.py
new file mode 100644
index 0000000000..7da72fa12b
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/__init__.py
@@ -0,0 +1,50 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_hunyuanimage"] = ["HunyuanImagePipeline"]
+ _import_structure["pipeline_hunyuanimage_refiner"] = ["HunyuanImageRefinerPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_hunyuanimage import HunyuanImagePipeline
+ from .pipeline_hunyuanimage_refiner import HunyuanImageRefinerPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py
new file mode 100644
index 0000000000..658935ccd8
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py
@@ -0,0 +1,866 @@
+# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
+
+from ...guiders import AdaptiveProjectedMixGuidance
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import HunyuanImagePipeline
+
+ >>> pipe = HunyuanImagePipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanImage-2.1-Diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, negative_prompt="", num_inference_steps=50).images[0]
+ >>> image.save("hunyuanimage.png")
+ ```
+"""
+
+
+def extract_glyph_text(prompt: str):
+ """
+ Extract text enclosed in quotes for glyph rendering.
+
+ Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
+
+ Args:
+ prompt: Input text prompt
+
+ Returns:
+ Formatted glyph text string or None if no quoted text found
+ """
+ text_prompt_texts = []
+ pattern_quote_single = r"\'(.*?)\'"
+ pattern_quote_double = r"\"(.*?)\""
+ pattern_quote_chinese_single = r"‘(.*?)’"
+ pattern_quote_chinese_double = r"“(.*?)”"
+
+ matches_quote_single = re.findall(pattern_quote_single, prompt)
+ matches_quote_double = re.findall(pattern_quote_double, prompt)
+ matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, prompt)
+ matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, prompt)
+
+ text_prompt_texts.extend(matches_quote_single)
+ text_prompt_texts.extend(matches_quote_double)
+ text_prompt_texts.extend(matches_quote_chinese_single)
+ text_prompt_texts.extend(matches_quote_chinese_double)
+
+ if text_prompt_texts:
+ glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
+ else:
+ glyph_text_formatted = None
+
+ return glyph_text_formatted
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class HunyuanImagePipeline(DiffusionPipeline):
+ r"""
+ The HunyuanImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`HunyuanImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanImage`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+ variant.
+ tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+ guider ([`AdaptiveProjectedMixGuidance`]):
+ [AdaptiveProjectedMixGuidance]to be used to guide the image generation.
+ ocr_guider ([`AdaptiveProjectedMixGuidance`], *optional*):
+ [AdaptiveProjectedMixGuidance] to be used to guide the image generation when text rendering is needed.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["ocr_guider", "guider"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLHunyuanImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: ByT5Tokenizer,
+ transformer: HunyuanImageTransformer2DModel,
+ guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ocr_guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ ocr_guider=ocr_guider,
+ )
+
+ self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 32
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = 1000
+ self.tokenizer_2_max_length = 128
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+ self.prompt_template_encode_start_idx = 34
+ self.default_sample_size = 64
+
+ def _get_qwen_prompt_embeds(
+ self,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 1000,
+ template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
+ drop_idx: int = 34,
+ hidden_state_skip_layer: int = 2,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
+ ).to(device)
+
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
+
+ prompt_embeds = prompt_embeds[:, drop_idx:]
+ encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = encoder_attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def _get_byt5_prompt_embeds(
+ self,
+ tokenizer: ByT5Tokenizer,
+ text_encoder: T5EncoderModel,
+ prompt: str,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 128,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ if isinstance(prompt, list):
+ raise ValueError("byt5 prompt should be a string")
+ elif prompt is None:
+ raise ValueError("byt5 prompt should not be None")
+
+ txt_tokens = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer_max_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ ).to(device)
+
+ prompt_embeds = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask.float(),
+ )[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ batch_size: int = 1,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ template=self.prompt_template_encode,
+ drop_idx=self.prompt_template_encode_start_idx,
+ )
+
+ if prompt_embeds_2 is None:
+ prompt_embeds_2_list = []
+ prompt_embeds_mask_2_list = []
+
+ glyph_texts = [extract_glyph_text(p) for p in prompt]
+ for glyph_text in glyph_texts:
+ if glyph_text is None:
+ glyph_text_embeds = torch.zeros(
+ (1, self.tokenizer_2_max_length, self.text_encoder_2.config.d_model), device=device
+ )
+ glyph_text_embeds_mask = torch.zeros(
+ (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
+ )
+ else:
+ glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
+ tokenizer=self.tokenizer_2,
+ text_encoder=self.text_encoder_2,
+ prompt=glyph_text,
+ device=device,
+ tokenizer_max_length=self.tokenizer_2_max_length,
+ )
+
+ prompt_embeds_2_list.append(glyph_text_embeds)
+ prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
+
+ prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
+ prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ _, seq_len_2, _ = prompt_embeds_2.shape
+ prompt_embeds_2 = prompt_embeds_2.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_2 = prompt_embeds_2.view(batch_size * num_images_per_prompt, seq_len_2, -1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask_2 = prompt_embeds_mask_2.view(batch_size * num_images_per_prompt, seq_len_2)
+
+ return prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ prompt_embeds_2=None,
+ prompt_embeds_mask_2=None,
+ negative_prompt_embeds_2=None,
+ negative_prompt_embeds_mask_2=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if prompt is None and prompt_embeds_2 is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
+ )
+
+ if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
+ )
+ if negative_prompt_embeds_2 is not None and negative_prompt_embeds_mask_2 is None:
+ raise ValueError(
+ "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ distilled_guidance_scale: Optional[float] = 3.25,
+ sigmas: Optional[List[float]] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ prompt_embeds_2: Optional[torch.Tensor] = None,
+ prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_2: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask_2: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined and negative_prompt_embeds is
+ not provided, will use an empty negative prompt. Ignored when not using guidance. ).
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ distilled_guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
+ guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
+ ignored.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, text embeddings mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, text embeddings for ocr will be generated from `prompt` input argument.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, text embeddings mask for ocr will be generated from `prompt` input
+ argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings mask. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative text embeddings mask will be generated from `negative_prompt`
+ input argument.
+ negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings for ocr. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative text embeddings for ocr will be generated from `negative_prompt`
+ input argument.
+ negative_prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings mask for ocr. Can be used to easily tweak text inputs, *e.g.*
+ prompt weighting. If not provided, negative text embeddings mask for ocr will be generated from
+ `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ negative_prompt_embeds_2=negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. prepare prompt embeds
+
+ prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds_2=prompt_embeds_2,
+ prompt_embeds_mask_2=prompt_embeds_mask_2,
+ )
+
+ prompt_embeds = prompt_embeds.to(self.transformer.dtype)
+ prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
+
+ # select guider
+ if not torch.all(prompt_embeds_2 == 0) and self.ocr_guider is not None:
+ # prompt contains ocr and pipeline has a guider for ocr
+ guider = self.ocr_guider
+ elif self.guider is not None:
+ guider = self.guider
+ # distilled model does not use guidance method, use default guider with enabled=False
+ else:
+ guider = AdaptiveProjectedMixGuidance(enabled=False)
+
+ if guider._enabled and guider.num_conditions > 1:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ negative_prompt_embeds_2,
+ negative_prompt_embeds_mask_2,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds_2=negative_prompt_embeds_2,
+ prompt_embeds_mask_2=negative_prompt_embeds_mask_2,
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
+ negative_prompt_embeds_2 = negative_prompt_embeds_2.to(self.transformer.dtype)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=num_channels_latents,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance (for guidance-distilled model)
+ if self.transformer.config.guidance_embeds and distilled_guidance_scale is None:
+ raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
+
+ if self.transformer.config.guidance_embeds:
+ guidance = (
+ torch.tensor(
+ [distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device
+ )
+ * 1000.0
+ )
+
+ else:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ if self.transformer.config.use_meanflow:
+ if i == len(timesteps) - 1:
+ timestep_r = torch.tensor([0.0], device=device)
+ else:
+ timestep_r = timesteps[i + 1]
+ timestep_r = timestep_r.expand(latents.shape[0]).to(latents.dtype)
+ else:
+ timestep_r = None
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ "encoder_hidden_states_2": (prompt_embeds_2, negative_prompt_embeds_2),
+ "encoder_attention_mask_2": (prompt_embeds_mask_2, negative_prompt_embeds_mask_2),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = guider.prepare_inputs(guider_inputs)
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep,
+ timestep_r=timestep_r,
+ guidance=guidance,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return HunyuanImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py
new file mode 100644
index 0000000000..f38f53d9a5
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py
@@ -0,0 +1,752 @@
+# Copyright 2025 Hunyuan-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
+
+from ...guiders import AdaptiveProjectedMixGuidance
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import HunyuanImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import HunyuanImageRefinerPipeline
+
+ >>> pipe = HunyuanImageRefinerPipeline.from_pretrained(
+ ... "hunyuanvideo-community/HunyuanImage-2.1-Refiner-Diffusers", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+ >>> prompt = "A cat holding a sign that says hello world"
+ >>> image = load_image("path/to/image.png")
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(prompt, image=image, num_inference_steps=4).images[0]
+ >>> image.save("hunyuanimage.png")
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class HunyuanImageRefinerPipeline(DiffusionPipeline):
+ r"""
+ The HunyuanImage pipeline for text-to-image generation.
+
+ Args:
+ transformer ([`HunyuanImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLHunyuanImageRefiner`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["guider"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLHunyuanImageRefiner,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ transformer: HunyuanImageTransformer2DModel,
+ guider: Optional[AdaptiveProjectedMixGuidance] = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ guider=guider,
+ )
+
+ self.vae_scale_factor = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.tokenizer_max_length = 256
+ self.prompt_template_encode = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ self.prompt_template_encode_start_idx = 36
+ self.default_sample_size = 64
+ self.latent_channels = self.transformer.config.in_channels // 2 if getattr(self, "transformer", None) else 64
+
+ # Copied from diffusers.pipelines.hunyuan_image.pipeline_hunyuanimage.HunyuanImagePipeline._get_qwen_prompt_embeds
+ def _get_qwen_prompt_embeds(
+ self,
+ tokenizer: Qwen2Tokenizer,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ prompt: Union[str, List[str]] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ tokenizer_max_length: int = 1000,
+ template: str = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>",
+ drop_idx: int = 34,
+ hidden_state_skip_layer: int = 2,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ txt = [template.format(e) for e in prompt]
+ txt_tokens = tokenizer(
+ txt, max_length=tokenizer_max_length + drop_idx, padding="max_length", truncation=True, return_tensors="pt"
+ ).to(device)
+
+ encoder_hidden_states = text_encoder(
+ input_ids=txt_tokens.input_ids,
+ attention_mask=txt_tokens.attention_mask,
+ output_hidden_states=True,
+ )
+ prompt_embeds = encoder_hidden_states.hidden_states[-(hidden_state_skip_layer + 1)]
+
+ prompt_embeds = prompt_embeds[:, drop_idx:]
+ encoder_attention_mask = txt_tokens.attention_mask[:, drop_idx:]
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ encoder_attention_mask = encoder_attention_mask.to(device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ device: Optional[torch.device] = None,
+ batch_size: int = 1,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ batch_size (`int`):
+ batch size of prompts, defaults to 1
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+ argument.
+ prompt_embeds_mask (`torch.Tensor`, *optional*):
+ Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
+ prompt_embeds_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
+ Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+ argument using self.tokenizer_2 and self.text_encoder_2.
+ """
+ device = device or self._execution_device
+
+ if prompt is None:
+ prompt = [""] * batch_size
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
+ tokenizer=self.tokenizer,
+ text_encoder=self.text_encoder,
+ prompt=prompt,
+ device=device,
+ tokenizer_max_length=self.tokenizer_max_length,
+ template=self.prompt_template_encode,
+ drop_idx=self.prompt_template_encode_start_idx,
+ )
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ def prepare_latents(
+ self,
+ image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ strength=0.25,
+ ):
+ height = int(height) // self.vae_scale_factor
+ width = int(width) // self.vae_scale_factor
+
+ shape = (batch_size, num_channels_latents, 1, height, width)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ cond_latents = strength * noise + (1 - strength) * image_latents
+
+ return latents, cond_latents
+
+ @staticmethod
+ def _reorder_image_tokens(image_latents):
+ image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
+ batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = image_latents.shape
+ image_latents = image_latents.permute(0, 2, 1, 3, 4)
+ image_latents = image_latents.reshape(
+ batch_size, num_latent_frames // 2, num_latent_channels * 2, latent_height, latent_width
+ )
+ image_latents = image_latents.permute(0, 2, 1, 3, 4).contiguous()
+
+ return image_latents
+
+ @staticmethod
+ def _restore_image_tokens_order(latents):
+ """Restore image tokens order by splitting channels and removing first frame slice."""
+ batch_size, num_latent_channels, num_latent_frames, latent_height, latent_width = latents.shape
+
+ latents = latents.permute(0, 2, 1, 3, 4) # B, F, C, H, W
+ latents = latents.reshape(
+ batch_size, num_latent_frames * 2, num_latent_channels // 2, latent_height, latent_width
+ ) # B, F*2, C//2, H, W
+
+ latents = latents.permute(0, 2, 1, 3, 4) # B, C//2, F*2, H, W
+ # Remove first frame slice
+ latents = latents[:, :, 1:]
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="sample")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
+ image_latents = self._reorder_image_tokens(image_latents)
+
+ image_latents = image_latents * self.vae.config.scaling_factor
+
+ return image_latents
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ distilled_guidance_scale: Optional[float] = 3.25,
+ image: Optional[PipelineImageInput] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 4,
+ sigmas: Optional[List[float]] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, will use an empty negative
+ prompt. Ignored when not using guidance.
+ distilled_guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ is enabled by setting `distilled_guidance_scale > 1`. Higher guidance scale encourages to generate
+ images that are closely linked to the text `prompt`, usually at the expense of lower image quality. For
+ guidance distilled models, this parameter is required. For non-distilled models, this parameter will be
+ ignored.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] or `tuple`:
+ [`~pipelines.hunyuan_image.HunyuanImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. process image
+ if image is not None and isinstance(image, torch.Tensor) and image.shape[1] == self.latent_channels:
+ image_latents = image
+ else:
+ image = self.image_processor.preprocess(image, height, width)
+ image = image.unsqueeze(2).to(device, dtype=self.vae.dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ # 3.prepare prompt embeds
+
+ if self.guider is not None:
+ guider = self.guider
+ else:
+ # distilled model does not use guidance method, use default guider with enabled=False
+ guider = AdaptiveProjectedMixGuidance(enabled=False)
+
+ requires_unconditional_embeds = guider._enabled and guider.num_conditions > 1
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ prompt_embeds = prompt_embeds.to(self.transformer.dtype)
+
+ if requires_unconditional_embeds:
+ (
+ negative_prompt_embeds,
+ negative_prompt_embeds_mask,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ batch_size=batch_size,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.to(self.transformer.dtype)
+
+ # 4. Prepare latent variables
+ latents, cond_latents = self.prepare_latents(
+ image_latents=image_latents,
+ batch_size=batch_size * num_images_per_prompt,
+ num_channels_latents=self.latent_channels,
+ height=height,
+ width=width,
+ dtype=prompt_embeds.dtype,
+ device=device,
+ generator=generator,
+ latents=latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance (this pipeline only supports guidance-distilled models)
+ if distilled_guidance_scale is None:
+ raise ValueError("`distilled_guidance_scale` is required for guidance-distilled model.")
+ guidance = (
+ torch.tensor([distilled_guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device)
+ * 1000.0
+ )
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ latent_model_input = torch.cat([latents, cond_latents], dim=1).to(self.transformer.dtype)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ # Step 1: Collect model inputs needed for the guidance method
+ # conditional inputs should always be first element in the tuple
+ guider_inputs = {
+ "encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
+ "encoder_attention_mask": (prompt_embeds_mask, negative_prompt_embeds_mask),
+ }
+
+ # Step 2: Update guider's internal state for this denoising step
+ guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
+
+ # Step 3: Prepare batched model inputs based on the guidance method
+ # The guider splits model inputs into separate batches for conditional/unconditional predictions.
+ # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
+ # you will get a guider_state with two batches:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"}, # conditional batch
+ # {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"}, # unconditional batch
+ # ]
+ # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
+ guider_state = guider.prepare_inputs(guider_inputs)
+
+ # Step 4: Run the denoiser for each batch
+ # Each batch in guider_state represents a different conditioning (conditional, unconditional, etc.).
+ # We run the model once per batch and store the noise prediction in guider_state_batch.noise_pred.
+ for guider_state_batch in guider_state:
+ guider.prepare_models(self.transformer)
+
+ # Extract conditioning kwargs for this batch (e.g., encoder_hidden_states)
+ cond_kwargs = {
+ input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()
+ }
+
+ # e.g. "pred_cond"/"pred_uncond"
+ context_name = getattr(guider_state_batch, guider._identifier_key)
+ with self.transformer.cache_context(context_name):
+ # Run denoiser and store noise prediction in this batch
+ guider_state_batch.noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ guidance=guidance,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ **cond_kwargs,
+ )[0]
+
+ # Cleanup model (e.g., remove hooks)
+ guider.cleanup_models(self.transformer)
+
+ # Step 5: Combine predictions using the guidance method
+ # The guider takes all noise predictions from guider_state and combines them according to the guidance algorithm.
+ # Continuing the CFG example, the guider receives:
+ # guider_state = [
+ # {"encoder_hidden_states": prompt_embeds, "noise_pred": noise_pred_cond, "__guidance_identifier__": "pred_cond"}, # batch 0
+ # {"encoder_hidden_states": negative_prompt_embeds, "noise_pred": noise_pred_uncond, "__guidance_identifier__": "pred_uncond"}, # batch 1
+ # ]
+ # And extracts predictions using the __guidance_identifier__:
+ # pred_cond = guider_state[0]["noise_pred"] # extracts noise_pred_cond
+ # pred_uncond = guider_state[1]["noise_pred"] # extracts noise_pred_uncond
+ # Then applies CFG formula:
+ # noise_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
+ # Returns GuiderOutput(pred=noise_pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
+ noise_pred = guider(guider_state)[0]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
+ latents = self._restore_image_tokens_order(latents)
+
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image.squeeze(2), output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return HunyuanImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuan_image/pipeline_output.py b/src/diffusers/pipelines/hunyuan_image/pipeline_output.py
new file mode 100644
index 0000000000..1e76892a0e
--- /dev/null
+++ b/src/diffusers/pipelines/hunyuan_image/pipeline_output.py
@@ -0,0 +1,21 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class HunyuanImagePipelineOutput(BaseOutput):
+ """
+ Output class for HunyuanImage pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
index c7b8022c22..73c2688975 100644
--- a/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
+++ b/src/diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py
@@ -113,7 +113,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixi
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
- _cut_context=False,
+ _cut_context=True,
attention_mask: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
):
diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py
new file mode 100644
index 0000000000..a7975bdce9
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_kandinsky import Kandinsky5T2VPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
new file mode 100644
index 0000000000..3f93aa1889
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py
@@ -0,0 +1,900 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from torch.nn import functional as F
+from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import KandinskyLoraLoaderMixin
+from ...models import AutoencoderKLHunyuanVideo
+from ...models.transformers import Kandinsky5Transformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import KandinskyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__)
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+
+ ```python
+ >>> import torch
+ >>> from diffusers import Kandinsky5T2VPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Available models:
+ >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
+ >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
+ >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
+ >>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
+
+ >>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
+ >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen."
+ >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=512,
+ ... width=768,
+ ... num_frames=121,
+ ... num_inference_steps=50,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> export_to_video(output, "output.mp4", fps=24, quality=9)
+ ```
+"""
+
+
+def basic_clean(text):
+ """Clean text using ftfy if available and unescape HTML entities."""
+ if is_ftfy_available():
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ """Normalize whitespace in text by replacing multiple spaces with single space."""
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ """Apply both basic cleaning and whitespace normalization to prompts."""
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using Kandinsky 5.0.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ transformer ([`Kandinsky5Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded video latents.
+ vae ([`AutoencoderKLHunyuanVideo`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
+ Frozen text-encoder (Qwen2.5-VL).
+ tokenizer ([`AutoProcessor`]):
+ Tokenizer for Qwen2.5-VL.
+ text_encoder_2 ([`CLIPTextModel`]):
+ Frozen CLIP text encoder.
+ tokenizer_2 ([`CLIPTokenizer`]):
+ Tokenizer for CLIP.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
+ _callback_tensor_inputs = [
+ "latents",
+ "prompt_embeds_qwen",
+ "prompt_embeds_clip",
+ "negative_prompt_embeds_qwen",
+ "negative_prompt_embeds_clip",
+ ]
+
+ def __init__(
+ self,
+ transformer: Kandinsky5Transformer3DModel,
+ vae: AutoencoderKLHunyuanVideo,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2VLProcessor,
+ text_encoder_2: CLIPTextModel,
+ tokenizer_2: CLIPTokenizer,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ transformer=transformer,
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_encoder_2=text_encoder_2,
+ tokenizer_2=tokenizer_2,
+ scheduler=scheduler,
+ )
+
+ self.prompt_template = "\n".join(
+ [
+ "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.",
+ "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.",
+ "Describe the location of the video, main characters or objects and their action.",
+ "Describe the dynamism of the video and presented actions.",
+ "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.",
+ "Describe the visual effects, postprocessing and transitions if they are presented in the video.",
+ "Pay attention to the order of key actions shown in the scene.<|im_end|>",
+ "<|im_start|>user\n{}<|im_end|>",
+ ]
+ )
+ self.prompt_template_encode_start_idx = 129
+
+ self.vae_scale_factor_temporal = (
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
+ )
+ self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ @staticmethod
+ def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor:
+ """
+ Create a sparse temporal attention (STA) mask for efficient video generation.
+
+ This method generates a mask that limits attention to nearby frames and spatial positions, reducing
+ computational complexity for video generation.
+
+ Args:
+ T (int): Number of temporal frames
+ H (int): Height in latent space
+ W (int): Width in latent space
+ wT (int): Temporal attention window size
+ wH (int): Height attention window size
+ wW (int): Width attention window size
+ device (str): Device to create tensor on
+
+ Returns:
+ torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W)
+ """
+ l = torch.Tensor([T, H, W]).amax()
+ r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+ mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+ sta_t, sta_h, sta_w = (
+ mat[:T, :T].flatten(),
+ mat[:H, :H].flatten(),
+ mat[:W, :W].flatten(),
+ )
+ sta_t = sta_t <= wT // 2
+ sta_h = sta_h <= wH // 2
+ sta_w = sta_w <= wW // 2
+ sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
+ sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
+ return sta.reshape(T * H * W, T * H * W)
+
+ def get_sparse_params(self, sample, device):
+ """
+ Generate sparse attention parameters for the transformer based on sample dimensions.
+
+ This method computes the sparse attention configuration needed for efficient video processing in the
+ transformer model.
+
+ Args:
+ sample (torch.Tensor): Input sample tensor
+ device (torch.device): Device to place tensors on
+
+ Returns:
+ Dict: Dictionary containing sparse attention parameters
+ """
+ assert self.transformer.config.patch_size[0] == 1
+ B, T, H, W, _ = sample.shape
+ T, H, W = (
+ T // self.transformer.config.patch_size[0],
+ H // self.transformer.config.patch_size[1],
+ W // self.transformer.config.patch_size[2],
+ )
+ if self.transformer.config.attention_type == "nabla":
+ sta_mask = self.fast_sta_nabla(
+ T,
+ H // 8,
+ W // 8,
+ self.transformer.config.attention_wT,
+ self.transformer.config.attention_wH,
+ self.transformer.config.attention_wW,
+ device=device,
+ )
+
+ sparse_params = {
+ "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
+ "attention_type": self.transformer.config.attention_type,
+ "to_fractal": True,
+ "P": self.transformer.config.attention_P,
+ "wT": self.transformer.config.attention_wT,
+ "wW": self.transformer.config.attention_wW,
+ "wH": self.transformer.config.attention_wH,
+ "add_sta": self.transformer.config.attention_add_sta,
+ "visual_shape": (T, H, W),
+ "method": self.transformer.config.attention_method,
+ }
+ else:
+ sparse_params = None
+
+ return sparse_params
+
+ def _encode_prompt_qwen(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using Qwen2.5-VL text encoder.
+
+ This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
+ video generation.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ num_videos_per_prompt (int): Number of videos to generate per prompt
+ max_sequence_length (int): Maximum sequence length for tokenization
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ full_texts = [self.prompt_template.format(p) for p in prompt]
+
+ inputs = self.tokenizer(
+ text=full_texts,
+ images=None,
+ videos=None,
+ max_length=max_sequence_length + self.prompt_template_encode_start_idx,
+ truncation=True,
+ return_tensors="pt",
+ padding=True,
+ ).to(device)
+
+ embeds = self.text_encoder(
+ input_ids=inputs["input_ids"],
+ return_dict=True,
+ output_hidden_states=True,
+ )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :]
+
+ attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :]
+ cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
+
+ return embeds.to(dtype), cu_seqlens
+
+ def _encode_prompt_clip(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ """
+ Encode prompt using CLIP text encoder.
+
+ This method processes the input prompt through the CLIP model to generate pooled embeddings that capture
+ semantic information.
+
+ Args:
+ prompt (Union[str, List[str]]): Input prompt or list of prompts
+ device (torch.device): Device to run encoding on
+ num_videos_per_prompt (int): Number of videos to generate per prompt
+ dtype (torch.dtype): Data type for embeddings
+
+ Returns:
+ torch.Tensor: Pooled text embeddings from CLIP
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder_2.dtype
+
+ inputs = self.tokenizer_2(
+ prompt,
+ max_length=77,
+ truncation=True,
+ add_special_tokens=True,
+ padding="max_length",
+ return_tensors="pt",
+ ).to(device)
+
+ pooled_embed = self.text_encoder_2(**inputs)["pooler_output"]
+
+ return pooled_embed.to(dtype)
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes a single prompt (positive or negative) into text encoder hidden states.
+
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text
+ representations for video generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ Prompt to be encoded.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos to generate per prompt.
+ max_sequence_length (`int`, *optional*, defaults to 512):
+ Maximum sequence length for text encoding.
+ device (`torch.device`, *optional*):
+ Torch device.
+ dtype (`torch.dtype`, *optional*):
+ Torch dtype.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim)
+ - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim)
+ - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size *
+ num_videos_per_prompt + 1,)
+ """
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ if not isinstance(prompt, list):
+ prompt = [prompt]
+
+ batch_size = len(prompt)
+
+ prompt = [prompt_clean(p) for p in prompt]
+
+ # Encode with Qwen2.5-VL
+ prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ dtype=dtype,
+ )
+ # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim]
+
+ # Encode with CLIP
+ prompt_embeds_clip = self._encode_prompt_clip(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ )
+ # prompt_embeds_clip shape: [batch_size, clip_embed_dim]
+
+ # Repeat embeddings for num_videos_per_prompt
+ # Qwen embeddings: repeat sequence for each video, then reshape
+ prompt_embeds_qwen = prompt_embeds_qwen.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim]
+ prompt_embeds_qwen = prompt_embeds_qwen.view(
+ batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]
+ )
+
+ # CLIP embeddings: repeat for each video
+ prompt_embeds_clip = prompt_embeds_clip.repeat(
+ 1, num_videos_per_prompt, 1
+ ) # [batch_size, num_videos_per_prompt, clip_embed_dim]
+ # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim]
+ prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1)
+
+ # Repeat cumulative sequence lengths for num_videos_per_prompt
+ # Original cu_seqlens: [0, len1, len1+len2, ...]
+ # Need to repeat the differences and reconstruct for repeated prompts
+ # Original differences (lengths) for each prompt in the batch
+ original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...]
+ # Repeat the lengths for num_videos_per_prompt
+ repeated_lengths = original_lengths.repeat_interleave(
+ num_videos_per_prompt
+ ) # [len1, len1, ..., len2, len2, ...]
+ # Reconstruct the cumulative lengths
+ repeated_cu_seqlens = torch.cat(
+ [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]
+ )
+
+ return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds_qwen=None,
+ prompt_embeds_clip=None,
+ negative_prompt_embeds_qwen=None,
+ negative_prompt_embeds_clip=None,
+ prompt_cu_seqlens=None,
+ negative_prompt_cu_seqlens=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ """
+ Validate input parameters for the pipeline.
+
+ Args:
+ prompt: Input prompt
+ negative_prompt: Negative prompt for guidance
+ height: Video height
+ width: Video width
+ prompt_embeds_qwen: Pre-computed Qwen prompt embeddings
+ prompt_embeds_clip: Pre-computed CLIP prompt embeddings
+ negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings
+ negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings
+ prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt
+ negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt
+ callback_on_step_end_tensor_inputs: Callback tensor inputs
+
+ Raises:
+ ValueError: If inputs are invalid
+ """
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ # Check for consistency within positive prompt embeddings and sequence lengths
+ if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None:
+ if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None:
+ raise ValueError(
+ "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check for consistency within negative prompt embeddings and sequence lengths
+ if (
+ negative_prompt_embeds_qwen is not None
+ or negative_prompt_embeds_clip is not None
+ or negative_prompt_cu_seqlens is not None
+ ):
+ if (
+ negative_prompt_embeds_qwen is None
+ or negative_prompt_embeds_clip is None
+ or negative_prompt_cu_seqlens is None
+ ):
+ raise ValueError(
+ "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, "
+ "all three must be provided."
+ )
+
+ # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive)
+ if prompt is None and prompt_embeds_qwen is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined."
+ )
+
+ # Validate types for prompt and negative_prompt if provided
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ if negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Prepare initial latent variables for video generation.
+
+ This method creates random noise latents or uses provided latents as starting point for the denoising process.
+
+ Args:
+ batch_size (int): Number of videos to generate
+ num_channels_latents (int): Number of channels in latent space
+ height (int): Height of generated video
+ width (int): Width of generated video
+ num_frames (int): Number of frames in video
+ dtype (torch.dtype): Data type for latents
+ device (torch.device): Device to create latents on
+ generator (torch.Generator): Random number generator
+ latents (torch.Tensor): Pre-existing latents to use
+
+ Returns:
+ torch.Tensor: Prepared latent tensor
+ """
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ if self.transformer.visual_cond:
+ # For visual conditioning, concatenate with zeros and mask
+ visual_cond = torch.zeros_like(latents)
+ visual_cond_mask = torch.zeros(
+ [
+ batch_size,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ 1,
+ ],
+ dtype=latents.dtype,
+ device=latents.device,
+ )
+ latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ """Get the current guidance scale value."""
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ """Check if classifier-free guidance is enabled."""
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ """Get the number of denoising timesteps."""
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ """Check if generation has been interrupted."""
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ height: int = 512,
+ width: int = 768,
+ num_frames: int = 121,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ prompt_embeds_clip: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
+ prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `512`):
+ The height in pixels of the generated video.
+ width (`int`, defaults to `768`):
+ The width in pixels of the generated video.
+ num_frames (`int`, defaults to `25`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in classifier-free guidance.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A torch generator to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`KandinskyPipelineOutput`].
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function that is called at the end of each denoising step.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length for text encoding.
+
+ Examples:
+
+ Returns:
+ [`~KandinskyPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images.
+ """
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ height=height,
+ width=width,
+ prompt_embeds_qwen=prompt_embeds_qwen,
+ prompt_embeds_clip=prompt_embeds_clip,
+ negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+ negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+ prompt_cu_seqlens=prompt_cu_seqlens,
+ negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._interrupt = False
+
+ device = self._execution_device
+ dtype = self.transformer.dtype
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ prompt = [prompt]
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds_qwen.shape[0]
+
+ # 3. Encode input prompt
+ if prompt_embeds_qwen is None:
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+ prompt=prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if self.do_classifier_free_guidance:
+ if negative_prompt is None:
+ negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+ elif len(negative_prompt) != len(prompt):
+ raise ValueError(
+ f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+ )
+
+ if negative_prompt_embeds_qwen is None:
+ negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = (
+ self.encode_prompt(
+ prompt=negative_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_visual_dim
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare rope positions for positional encoding
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ visual_rope_pos = [
+ torch.arange(num_latent_frames, device=device),
+ torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+ torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+ ]
+
+ text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+ negative_text_rope_pos = (
+ torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+ if negative_prompt_cu_seqlens is not None
+ else None
+ )
+
+ # 7. Sparse Params for efficient attention
+ sparse_params = self.get_sparse_params(latents, device)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
+
+ # Predict noise residual
+ pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+ pooled_projections=prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=text_rope_pos,
+ scale_factor=(1, 2, 2),
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None:
+ uncond_pred_velocity = self.transformer(
+ hidden_states=latents.to(dtype),
+ encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+ pooled_projections=negative_prompt_embeds_clip.to(dtype),
+ timestep=timestep.to(dtype),
+ visual_rope_pos=visual_rope_pos,
+ text_rope_pos=negative_text_rope_pos,
+ scale_factor=(1, 2, 2),
+ sparse_params=sparse_params,
+ return_dict=True,
+ ).sample
+
+ pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+ # Compute previous sample using the scheduler
+ latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+ pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+ prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+ negative_prompt_embeds_qwen = callback_outputs.pop(
+ "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+ )
+ negative_prompt_embeds_clip = callback_outputs.pop(
+ "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+ )
+
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # 8. Post-processing - extract main latents
+ latents = latents[:, :, :, :, :num_channels_latents]
+
+ # 9. Decode latents to video
+ if output_type != "latent":
+ latents = latents.to(self.vae.dtype)
+ # Reshape and normalize latents
+ video = latents.reshape(
+ batch_size,
+ num_videos_per_prompt,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ num_channels_latents,
+ )
+ video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width]
+ video = video.reshape(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+
+ # Normalize and decode through VAE
+ video = video / self.vae.config.scaling_factor
+ video = self.vae.decode(video).sample
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return KandinskyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
new file mode 100644
index 0000000000..ed77d42a9a
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class KandinskyPipelineOutput(BaseOutput):
+ r"""
+ Output class for Wan pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
index 1d7733982e..59f733a498 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py
@@ -186,8 +186,8 @@ class LatentConsistencyModelImg2ImgPipeline(
supports [`LCMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
requires_safety_checker (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index 3e96b44663..e463884618 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -165,8 +165,8 @@ class LatentConsistencyModelPipeline(
supports [`LCMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
requires_safety_checker (`bool`, *optional*, defaults to `True`):
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
index bc50835d19..f1bf4701e3 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py
@@ -17,7 +17,6 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput
diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
index 273e97f1ec..631539e5c6 100644
--- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
+++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py
@@ -4,7 +4,6 @@ from typing import List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
-import torch.utils.checkpoint
from ...models import UNet2DModel, VQModel
from ...schedulers import (
diff --git a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
index 5b61aaf9b6..fbf4dc23d0 100644
--- a/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
+++ b/src/diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py
@@ -49,7 +49,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers.utils import load_image
>>> pipe = LEditsPPPipelineStableDiffusion.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", variant="fp16", torch_dtype=torch.float16
... )
>>> pipe.enable_vae_tiling()
>>> pipe = pipe.to("cuda")
@@ -381,8 +381,8 @@ class LEditsPPPipelineStableDiffusion(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
index 1e94f6895f..9acff105e5 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py
@@ -121,6 +121,38 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
result = torch.lerp(latents, result, factor)
return result
+ def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
+ """
+ Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
+ smooth way using a sigmoid-based compression.
+
+ This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
+ when controlling dynamic behavior with a `compression` factor.
+
+ Args:
+ latents : torch.Tensor
+ Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
+ compression : float
+ Compression strength in the range [0, 1].
+ - 0.0: No tone-mapping (identity transform)
+ - 1.0: Full compression effect
+
+ Returns:
+ torch.Tensor
+ The tone-mapped latent tensor of the same shape as input.
+ """
+ # Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
+ scale_factor = compression * 0.75
+ abs_latents = torch.abs(latents)
+
+ # Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
+ # When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
+ sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
+ scales = 1.0 - 0.8 * scale_factor * sigmoid_term
+
+ filtered = latents * scales
+ return filtered
+
@staticmethod
# Copied from diffusers.pipelines.ltx.pipeline_ltx.LTXPipeline._normalize_latents
def _normalize_latents(
@@ -196,7 +228,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
)
self.vae.disable_tiling()
- def check_inputs(self, video, height, width, latents):
+ def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -205,6 +237,9 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
+ if not (0 <= tone_map_compression_ratio <= 1):
+ raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
+
@torch.no_grad()
def __call__(
self,
@@ -215,6 +250,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
+ tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
@@ -224,6 +260,7 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
height=height,
width=width,
latents=latents,
+ tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
@@ -266,6 +303,9 @@ class LTXLatentUpsamplePipeline(DiffusionPipeline):
else:
latents = latents_upsampled
+ if tone_map_compression_ratio > 0.0:
+ latents = self.tone_map_latents(latents, tone_map_compression_ratio)
+
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
diff --git a/src/diffusers/pipelines/lucy/__init__.py b/src/diffusers/pipelines/lucy/__init__.py
new file mode 100644
index 0000000000..580e1f37f3
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/__init__.py
@@ -0,0 +1,47 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lucy_edit"] = ["LucyEditPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_lucy_edit import LucyEditPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
new file mode 100644
index 0000000000..69f69d5768
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_lucy_edit.py
@@ -0,0 +1,735 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The Decart AI Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Modifications by Decart AI Team:
+# - Based on pipeline_wan.py, but with supports recieving a condition video appended to the channel dimension.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import regex as re
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import WanLoraLoaderMixin
+from ...models import AutoencoderKLWan, WanTransformer3DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import LucyPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```python
+ >>> from typing import List
+
+ >>> import torch
+ >>> from PIL import Image
+
+ >>> from diffusers import AutoencoderKLWan, LucyEditPipeline
+ >>> from diffusers.utils import export_to_video, load_video
+
+ >>> # Arguments
+ >>> url = "https://d2drjpuinn46lb.cloudfront.net/painter_original_edit.mp4"
+ >>> prompt = "Change the apron and blouse to a classic clown costume: satin polka-dot jumpsuit in bright primary colors, ruffled white collar, oversized pom-pom buttons, white gloves, oversized red shoes, red foam nose; soft window light from left, eye-level medium shot, natural folds and fabric highlights."
+ >>> negative_prompt = ""
+ >>> num_frames = 81
+ >>> height = 480
+ >>> width = 832
+
+
+ >>> # Load video
+ >>> def convert_video(video: List[Image.Image]) -> List[Image.Image]:
+ ... video = load_video(url)[:num_frames]
+ ... video = [video[i].resize((width, height)) for i in range(num_frames)]
+ ... return video
+
+
+ >>> video = load_video(url, convert_method=convert_video)
+
+ >>> # Load model
+ >>> model_id = "decart-ai/Lucy-Edit-Dev"
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+ >>> pipe = LucyEditPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> # Generate video
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... video=video,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... num_frames=81,
+ ... guidance_scale=5.0,
+ ... ).frames[0]
+
+ >>> # Export video
+ >>> export_to_video(output, "output.mp4", fps=24)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class LucyEditPipeline(DiffusionPipeline, WanLoraLoaderMixin):
+ r"""
+ Pipeline for video-to-video generation using Lucy Edit.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`WanTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+ two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+ stages. If not provided, only `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer", "transformer_2"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ vae: AutoencoderKLWan,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: Optional[WanTransformer3DModel] = None,
+ transformer_2: Optional[WanTransformer3DModel] = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False, # Wan2.2 ti2v
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ transformer_2=transformer_2,
+ )
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.register_to_config(expand_timesteps=expand_timesteps)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if video is None:
+ raise ValueError("`video` is required, received None.")
+
+ def prepare_latents(
+ self,
+ video: Optional[torch.Tensor] = None,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ num_latent_frames = (
+ (video.size(2) - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.size(1)
+ )
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ height // self.vae_scale_factor_spatial,
+ width // self.vae_scale_factor_spatial,
+ )
+ # Prepare noise latents
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # Prepare condition latents
+ condition_latents = [
+ retrieve_latents(self.vae.encode(vid.unsqueeze(0)), sample_mode="argmax") for vid in video
+ ]
+
+ condition_latents = torch.cat(condition_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).to(device, dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, dtype
+ )
+
+ condition_latents = (condition_latents - latents_mean) * latents_std
+
+ # Check shapes
+ assert latents.shape == condition_latents.shape, (
+ f"Latents shape {latents.shape} does not match expected shape {condition_latents.shape}. Please check the input."
+ )
+
+ return latents, condition_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[Image.Image]`):
+ The video to use as the condition for the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (`guidance_scale` < `1`).
+ height (`int`, defaults to `480`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `832`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `81`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`LucyPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
+ truncated. If the prompt is shorter, it will be padded to this length.
+
+ Examples:
+
+ Returns:
+ [`~LucyPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`LucyPipelineOutput`] is returned, otherwise a `tuple` is returned where
+ the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ video,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
+ self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = (
+ self.transformer.config.out_channels
+ if self.transformer is not None
+ else self.transformer_2.config.out_channels
+ )
+ video = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition_latents = self.prepare_latents(
+ video,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
+ # latent_model_input = latents.to(transformer_dtype)
+ latent_model_input = torch.cat([latents, condition_latents], dim=1).to(transformer_dtype)
+ # latent_model_input = torch.cat([latents, latents], dim=1).to(transformer_dtype)
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return LucyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/lucy/pipeline_output.py b/src/diffusers/pipelines/lucy/pipeline_output.py
new file mode 100644
index 0000000000..cf9ea91fd1
--- /dev/null
+++ b/src/diffusers/pipelines/lucy/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class LucyPipelineOutput(BaseOutput):
+ r"""
+ Output class for Lucy pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
index da991aefbd..92ec16fd45 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_depth.py
@@ -86,15 +86,14 @@ class MarigoldDepthOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted depth maps with values in the range [0, 1]. The shape is $numimages \times 1 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 1$ for `np.ndarray`.
+ Predicted depth maps with values in the range [0, 1]. The shape is `numimages × 1 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
index c809de18f4..bef9ca77c7 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py
@@ -99,17 +99,17 @@ class MarigoldIntrinsicsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted image intrinsics with values in the range [0, 1]. The shape is $(numimages * numtargets) \times 3
- \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times height \times width
- \times 3$ for `np.ndarray`, where `numtargets` corresponds to the number of predicted target modalities of
- the intrinsic image decomposition.
+ Predicted image intrinsics with values in the range [0, 1]. The shape is `(numimages * numtargets) × 3 ×
+ height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for `np.ndarray`,
+ where `numtargets` corresponds to the number of predicted target modalities of the intrinsic image
+ decomposition.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $(numimages *
- numtargets) \times 3 \times height \times width$ for `torch.Tensor` or $(numimages * numtargets) \times
- height \times width \times 3$ for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `(numimages *
+ numtargets) × 3 × height × width` for `torch.Tensor` or `(numimages * numtargets) × height × width × 3` for
+ `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $(numimages * numensemble) \times (numtargets * 4) \times latentheight \times latentwidth$.
+ The shape is `(numimages * numensemble) × (numtargets * 4) × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
index 192ed590a4..485a39c995 100644
--- a/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
+++ b/src/diffusers/pipelines/marigold/pipeline_marigold_normals.py
@@ -81,15 +81,14 @@ class MarigoldNormalsOutput(BaseOutput):
Args:
prediction (`np.ndarray`, `torch.Tensor`):
- Predicted normals with values in the range [-1, 1]. The shape is $numimages \times 3 \times height \times
- width$ for `torch.Tensor` or $numimages \times height \times width \times 3$ for `np.ndarray`.
+ Predicted normals with values in the range [-1, 1]. The shape is `numimages × 3 × height × width` for
+ `torch.Tensor` or `numimages × height × width × 3` for `np.ndarray`.
uncertainty (`None`, `np.ndarray`, `torch.Tensor`):
- Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is $numimages
- \times 1 \times height \times width$ for `torch.Tensor` or $numimages \times height \times width \times 1$
- for `np.ndarray`.
+ Uncertainty maps computed from the ensemble, with values in the range [0, 1]. The shape is `numimages × 1 ×
+ height × width` for `torch.Tensor` or `numimages × height × width × 1` for `np.ndarray`.
latent (`None`, `torch.Tensor`):
Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
- The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
+ The shape is `numimages * numensemble × 4 × latentheight × latentwidth`.
"""
prediction: Union[np.ndarray, torch.Tensor]
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
index de66871922..1abef01430 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
@@ -80,7 +80,10 @@ EXAMPLE_DOC_STRING = """
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... enable_pag=True,
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
@@ -202,8 +205,8 @@ class StableDiffusionControlNetPAGPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
index 4c02b3dd6d..2781af7890 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py
@@ -93,7 +93,10 @@ EXAMPLE_DOC_STRING = """
... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
... )
>>> pipe = AutoPipelineForInpainting.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
+ ... controlnet=controlnet,
+ ... torch_dtype=torch.float16,
+ ... enable_pag=True,
... )
>>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
@@ -150,17 +153,14 @@ class StableDiffusionControlNetPAGInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
-
-
- This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
- ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
- default text-to-image Stable Diffusion checkpoints
- ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
- Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
+ > [!TIP] > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting >
+ ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
+ as well as > default text-to-image Stable Diffusion checkpoints >
+ ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
+ Default text-to-image > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+ on those, such as >
[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
@@ -179,8 +179,8 @@ class StableDiffusionControlNetPAGInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -1332,7 +1332,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
# 7.1 Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
index 349d006aad..ea64f8be2c 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py
@@ -57,7 +57,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import AutoPipelineForText2Image
>>> pipe = AutoPipelineForText2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
... )
>>> pipe = pipe.to("cuda")
@@ -190,8 +190,8 @@ class StableDiffusionPAGPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -272,8 +272,8 @@ class StableDiffusionPAGPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
index acb4e52340..941b675099 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
@@ -237,7 +237,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -326,7 +326,7 @@ class StableDiffusion3PAGPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSin
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
index e1819a79fb..f40dd52fc2 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py
@@ -253,7 +253,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -342,7 +342,7 @@ class StableDiffusion3PAGImg2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
index e9a846b5e2..8351112ce4 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py
@@ -61,7 +61,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers.utils import load_image
>>> pipe = AutoPipelineForImage2Image.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5",
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5",
... torch_dtype=torch.float16,
... enable_pag=True,
... )
@@ -185,8 +185,8 @@ class StableDiffusionPAGImg2ImgPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -267,8 +267,8 @@ class StableDiffusionPAGImg2ImgPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
index ee9d20f368..6b1b294e10 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py
@@ -58,7 +58,7 @@ EXAMPLE_DOC_STRING = """
>>> from diffusers import AutoPipelineForInpainting
>>> pipe = AutoPipelineForInpainting.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, enable_pag=True
... )
>>> pipe = pipe.to("cuda")
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
@@ -217,8 +217,8 @@ class StableDiffusionPAGInpaintPipeline(
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -299,8 +299,8 @@ class StableDiffusionPAGInpaintPipeline(
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
- " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
- " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- stable-diffusion-v1-5/stable-diffusion-v1-5"
+ " \n- stable-diffusion-v1-5/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
@@ -1183,7 +1183,7 @@ class StableDiffusionPAGInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
index 2e12a4a97f..2a8f7a448d 100644
--- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
+++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
@@ -1501,7 +1501,7 @@ class StableDiffusionXLPAGInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 3e22c9a845..c09992befb 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -158,11 +158,7 @@ def prepare_mask_and_masked_image(image, mask):
class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
_last_supported_version = "0.33.1"
r"""
-
-
- 🧪 This is an experimental feature!
-
-
+ > [!WARNING] > 🧪 This is an experimental feature!
Pipeline for image-guided image inpainting using Stable Diffusion.
@@ -183,8 +179,8 @@ class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableD
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index f69968022e..2724c764c7 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -276,12 +276,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
Can be used to overwrite load and saveable variables (the pipeline components) of the specific pipeline
class. The overwritten components are passed directly to the pipelines `__init__` method.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`.
Examples:
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index ee767eddcc..8868e942ce 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -19,12 +19,12 @@ import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
+import httpx
import requests
import torch
from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from .. import __version__
from ..utils import (
@@ -33,6 +33,7 @@ from ..utils import (
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
+ _maybe_remap_transformers_class,
deprecate,
get_class_from_dynamic_module,
is_accelerate_available,
@@ -48,10 +49,12 @@ from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transfo
if is_transformers_available():
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizerBase
- from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
+ if is_transformers_version("<=", "4.56.2"):
+ from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
+
if is_accelerate_available():
import accelerate
from accelerate import dispatch_model
@@ -73,6 +76,7 @@ LOADABLE_CLASSES = {
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
+ "BaseGuidance": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
@@ -112,7 +116,9 @@ def is_safetensors_compatible(filenames, passed_components=None, folder_names=No
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -191,7 +197,9 @@ def filter_model_files(filenames):
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
allowed_extensions = [wn.split(".")[-1] for wn in weight_names]
@@ -212,7 +220,9 @@ def variant_compatible_siblings(filenames, variant=None, ignore_patterns=None) -
]
if is_transformers_available():
- weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
+ weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME]
+ if is_transformers_version("<=", "4.56.2"):
+ weight_names += [TRANSFORMERS_FLAX_WEIGHTS_NAME]
# model_pytorch, diffusion_model_pytorch, ...
weight_prefixes = [w.split(".")[0] for w in weight_names]
@@ -348,6 +358,11 @@ def maybe_raise_or_warn(
"""Simple helper method to raise or warn in case incorrect module has been passed"""
if not is_pipeline_module:
library = importlib.import_module(library_name)
+
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -382,6 +397,11 @@ def simple_get_class_obj(library_name, class_name):
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
+
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
class_obj = getattr(library, class_name)
return class_obj
@@ -408,6 +428,10 @@ def get_class_obj_and_candidates(
# else we just import it from the library.
library = importlib.import_module(library_name)
+ # Handle deprecated Transformers classes
+ if library_name == "transformers":
+ class_name = _maybe_remap_transformers_class(class_name) or class_name
+
class_obj = getattr(library, class_name)
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
@@ -830,6 +854,9 @@ def load_sub_model(
else:
loading_kwargs["low_cpu_mem_usage"] = False
+ if is_transformers_model and is_transformers_version(">=", "4.57.0"):
+ loading_kwargs.pop("offload_state_dict")
+
if (
quantization_config is not None
and isinstance(quantization_config, PipelineQuantizationConfig)
@@ -855,6 +882,9 @@ def load_sub_model(
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
+ skip_keys = None
+ if hasattr(loaded_sub_model, "_skip_keys") and loaded_sub_model._skip_keys is not None:
+ skip_keys = loaded_sub_model._skip_keys
if needs_offloading_to_cpu:
dispatch_model(
@@ -863,9 +893,10 @@ def load_sub_model(
device_map=device_map,
force_hooks=True,
main_device=0,
+ skip_keys=skip_keys,
)
else:
- dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True)
+ dispatch_model(loaded_sub_model, device_map=device_map, force_hooks=True, skip_keys=skip_keys)
return loaded_sub_model
@@ -1102,7 +1133,7 @@ def _download_dduf_file(
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 01b3c56777..392d5fb3fe 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -23,6 +23,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
+import httpx
import numpy as np
import PIL.Image
import requests
@@ -36,9 +37,8 @@ from huggingface_hub import (
read_dduf_file,
snapshot_download,
)
-from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args
+from huggingface_hub.utils import HfHubHTTPError, OfflineModeIsEnabled, validate_hf_hub_args
from packaging import version
-from requests.exceptions import HTTPError
from tqdm.auto import tqdm
from typing_extensions import Self
@@ -372,12 +372,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
-
-
- If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
- the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-
+ > [!TIP] > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is.
+ Otherwise, > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
@@ -627,11 +623,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
`torch.float32` is used.
custom_pipeline (`str`, *optional*):
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
Can be either:
@@ -716,12 +708,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
dduf_file(`str`, *optional*):
Load weights from the specified dduf file.
-
-
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
- auth login`.
-
-
+ > [!TIP] > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
+ with `hf > auth login`.
Examples:
@@ -1508,11 +1496,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- A path to a *directory* (`./my_pipeline_directory/`) containing a custom pipeline. The directory
must contain a file called `pipeline.py` that defines the custom pipeline.
-
-
- 🧪 This is an experimental feature and may change in the future.
-
-
+ > [!WARNING] > 🧪 This is an experimental feature and may change in the future.
For more information on how to load and create custom pipelines, take a look at [How to contribute a
community pipeline](https://huggingface.co/docs/diffusers/main/en/using-diffusers/contribute_pipeline).
@@ -1566,12 +1550,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
`os.PathLike`:
A path to the downloaded pipeline.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login
-
-
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login
"""
cache_dir = kwargs.pop("cache_dir", None)
@@ -1616,7 +1596,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if not local_files_only:
try:
info = model_info(pretrained_model_name, token=token, revision=revision)
- except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e:
+ except (HfHubHTTPError, OfflineModeIsEnabled, requests.ConnectionError, httpx.NetworkError) as e:
logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
local_files_only = True
model_info_call_error = e # save error to reraise it if model is not cached locally
@@ -1944,12 +1924,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
option is enabled, you should observe lower GPU memory usage and a potential speed up during inference. Speed
up during training is not guaranteed.
-
-
- ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
- precedent.
-
-
+ > [!WARNING] > ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient
+ attention takes > precedent.
Parameters:
attention_op (`Callable`, *optional*):
@@ -2005,13 +1981,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
in slices to compute attention in several steps. For more than one attention head, the computation is performed
sequentially over each head. This is useful to save some memory in exchange for a small speed decrease.
-
-
- ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch
- 2.0 or xFormers. These attention computations are already very memory efficient so you won't need to enable
- this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious slow downs!
-
-
+ > [!WARNING] > ⚠️ Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA)
+ from PyTorch > 2.0 or xFormers. These attention computations are already very memory efficient so you won't
+ need to enable > this function. If you enable attention slicing with SDPA or xFormers, it can lead to serious
+ slow downs!
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
@@ -2288,11 +2261,7 @@ class StableDiffusionMixin:
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
@@ -2317,11 +2286,7 @@ class StableDiffusionMixin:
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
"""Disable QKV projection fusion if enabled.
-
-
- This API is 🧪 experimental.
-
-
+ > [!WARNING] > This API is 🧪 experimental.
Args:
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
diff --git a/src/diffusers/pipelines/prx/__init__.py b/src/diffusers/pipelines/prx/__init__.py
new file mode 100644
index 0000000000..87aaefbd13
--- /dev/null
+++ b/src/diffusers/pipelines/prx/__init__.py
@@ -0,0 +1,63 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_additional_imports = {}
+_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_prx"] = ["PRXPipeline"]
+
+# Import T5GemmaEncoder for pipeline loading compatibility
+try:
+ if is_transformers_available():
+ import transformers
+ from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+ _additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
+ # Patch transformers module directly for serialization
+ if not hasattr(transformers, "T5GemmaEncoder"):
+ transformers.T5GemmaEncoder = T5GemmaEncoder
+except ImportError:
+ pass
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .pipeline_output import PRXPipelineOutput
+ from .pipeline_prx import PRXPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
+ for name, value in _additional_imports.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/prx/pipeline_output.py b/src/diffusers/pipelines/prx/pipeline_output.py
new file mode 100644
index 0000000000..ea1bc9bf41
--- /dev/null
+++ b/src/diffusers/pipelines/prx/pipeline_output.py
@@ -0,0 +1,35 @@
+# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL.Image
+
+from ...utils import BaseOutput
+
+
+@dataclass
+class PRXPipelineOutput(BaseOutput):
+ """
+ Output class for PRX pipelines.
+
+ Args:
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ """
+
+ images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py
new file mode 100644
index 0000000000..a3bd3e6b45
--- /dev/null
+++ b/src/diffusers/pipelines/prx/pipeline_prx.py
@@ -0,0 +1,767 @@
+# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+from typing import Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import (
+ AutoTokenizer,
+ GemmaTokenizerFast,
+ T5TokenizerFast,
+)
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+from diffusers.image_processor import PixArtImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderDC, AutoencoderKL
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+DEFAULT_RESOLUTION = 512
+
+ASPECT_RATIO_256_BIN = {
+ "0.46": [160, 352],
+ "0.6": [192, 320],
+ "0.78": [224, 288],
+ "1.0": [256, 256],
+ "1.29": [288, 224],
+ "1.67": [320, 192],
+ "2.2": [352, 160],
+}
+
+ASPECT_RATIO_512_BIN = {
+ "0.5": [352, 704],
+ "0.57": [384, 672],
+ "0.6": [384, 640],
+ "0.68": [416, 608],
+ "0.78": [448, 576],
+ "0.88": [480, 544],
+ "1.0": [512, 512],
+ "1.13": [544, 480],
+ "1.29": [576, 448],
+ "1.46": [608, 416],
+ "1.67": [640, 384],
+ "1.75": [672, 384],
+ "2.0": [704, 352],
+}
+
+logger = logging.get_logger(__name__)
+
+
+class TextPreprocessor:
+ """Text preprocessing utility for PRXPipeline."""
+
+ def __init__(self):
+ """Initialize text preprocessor."""
+ self.bad_punct_regex = re.compile(
+ r"["
+ + "#®•©™&@·º½¾¿¡§~"
+ + r"\)"
+ + r"\("
+ + r"\]"
+ + r"\["
+ + r"\}"
+ + r"\{"
+ + r"\|"
+ + r"\\"
+ + r"\/"
+ + r"\*"
+ + r"]{1,}"
+ )
+
+ def clean_text(self, text: str) -> str:
+ """Clean text using comprehensive text processing logic."""
+ # See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
+ text = str(text)
+ text = ul.unquote_plus(text)
+ text = text.strip().lower()
+ text = re.sub("", "person", text)
+
+ # Remove all urls:
+ text = re.sub(
+ r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
+ "",
+ text,
+ ) # regex for urls
+
+ # @
+ text = re.sub(r"@[\w\d]+\b", "", text)
+
+ # 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
+ text = re.sub(r"[\u31c0-\u31ef]+", "", text)
+ text = re.sub(r"[\u31f0-\u31ff]+", "", text)
+ text = re.sub(r"[\u3200-\u32ff]+", "", text)
+ text = re.sub(r"[\u3300-\u33ff]+", "", text)
+ text = re.sub(r"[\u3400-\u4dbf]+", "", text)
+ text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
+ text = re.sub(r"[\u4e00-\u9fff]+", "", text)
+
+ # все виды тире / all types of dash --> "-"
+ text = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
+ "-",
+ text,
+ )
+
+ # кавычки к одному стандарту
+ text = re.sub(r"[`´«»" "¨]", '"', text)
+ text = re.sub(r"['']", "'", text)
+
+ # " and &
+ text = re.sub(r""?", "", text)
+ text = re.sub(r"&", "", text)
+
+ # ip addresses:
+ text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
+
+ # article ids:
+ text = re.sub(r"\d:\d\d\s+$", "", text)
+
+ # \n
+ text = re.sub(r"\\n", " ", text)
+
+ # "#123", "#12345..", "123456.."
+ text = re.sub(r"#\d{1,3}\b", "", text)
+ text = re.sub(r"#\d{5,}\b", "", text)
+ text = re.sub(r"\b\d{6,}\b", "", text)
+
+ # filenames:
+ text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
+
+ # Clean punctuation
+ text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
+ text = re.sub(r"[\.]{2,}", r" ", text)
+
+ text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ text = re.sub(r"\s+\.\s+", r" ", text) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, text)) > 3:
+ text = re.sub(regex2, " ", text)
+
+ # Basic cleaning
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ text = text.strip()
+
+ # Clean alphanumeric patterns
+ text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
+ text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
+ text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
+
+ # Common spam patterns
+ text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
+ text = re.sub(r"(free\s)?download(\sfree)?", "", text)
+ text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
+ text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
+ text = re.sub(r"\bpage\s+\d+\b", "", text)
+
+ text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
+ text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
+
+ # Final cleanup
+ text = re.sub(r"\b\s+\:\s+", r": ", text)
+ text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
+ text = re.sub(r"\s+", " ", text)
+
+ text.strip()
+
+ text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
+ text = re.sub(r"^[\'\_,\-\:;]", r"", text)
+ text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
+ text = re.sub(r"^\.\S+$", "", text)
+
+ return text.strip()
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import PRXPipeline
+
+ >>> # Load pipeline with from_pretrained
+ >>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
+ >>> pipe.to("cuda")
+
+ >>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
+ >>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
+ >>> image.save("prx_output.png")
+ ```
+"""
+
+
+class PRXPipeline(
+ DiffusionPipeline,
+ LoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+):
+ r"""
+ Pipeline for text-to-image generation using PRX Transformer.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ transformer ([`PRXTransformer2DModel`]):
+ The PRX transformer model to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ text_encoder ([`T5GemmaEncoder`]):
+ Text encoder model for encoding prompts.
+ tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
+ Tokenizer for the text encoder.
+ vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ _optional_components = ["vae"]
+
+ def __init__(
+ self,
+ transformer: PRXTransformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ text_encoder: T5GemmaEncoder,
+ tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
+ vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
+ default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
+ ):
+ super().__init__()
+
+ if PRXTransformer2DModel is None:
+ raise ImportError(
+ "PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
+ )
+
+ self.text_preprocessor = TextPreprocessor()
+ self.default_sample_size = default_sample_size
+ self._guidance_scale = 1.0
+
+ self.register_modules(
+ transformer=transformer,
+ scheduler=scheduler,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ )
+
+ self.register_to_config(default_sample_size=self.default_sample_size)
+
+ if vae is not None:
+ self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ else:
+ self.image_processor = None
+
+ @property
+ def vae_scale_factor(self):
+ if self.vae is None:
+ return 8
+ if hasattr(self.vae, "spatial_compression_ratio"):
+ return self.vae.spatial_compression_ratio
+ else: # Flux VAE
+ return 2 ** (len(self.vae.config.block_out_channels) - 1)
+
+ @property
+ def do_classifier_free_guidance(self):
+ """Check if classifier-free guidance is enabled based on guidance scale."""
+ return self._guidance_scale > 1.0
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ def get_default_resolution(self):
+ """Determine the default resolution based on the loaded VAE and config.
+
+ Returns:
+ int: The default sample size (height/width) to use for generation.
+ """
+ default_from_config = getattr(self.config, "default_sample_size", None)
+ if default_from_config is not None:
+ return default_from_config
+
+ return DEFAULT_RESOLUTION
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[torch.Generator] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ """Prepare initial latents for the diffusion process."""
+ if latents is None:
+ spatial_compression = self.vae_scale_factor
+ latent_height, latent_width = (
+ height // spatial_compression,
+ width // spatial_compression,
+ )
+ shape = (batch_size, num_channels_latents, latent_height, latent_width)
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+ return latents
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ ):
+ """Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
+ if device is None:
+ device = self._execution_device
+
+ if prompt_embeds is None:
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ # Encode the prompts
+ prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
+ self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
+ )
+
+ # Duplicate embeddings for each generation per prompt
+ if num_images_per_prompt > 1:
+ # Repeat prompt embeddings
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if prompt_attention_mask is not None:
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ # Repeat negative embeddings if using CFG
+ if do_classifier_free_guidance and negative_prompt_embeds is not None:
+ bs_embed, seq_len, _ = negative_prompt_embeds.shape
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if negative_prompt_attention_mask is not None:
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+
+ return (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds if do_classifier_free_guidance else None,
+ negative_prompt_attention_mask if do_classifier_free_guidance else None,
+ )
+
+ def _tokenize_prompts(self, prompts: List[str], device: torch.device):
+ """Tokenize and clean prompts."""
+ cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
+ tokens = self.tokenizer(
+ cleaned,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
+
+ def _encode_prompt_standard(
+ self,
+ prompt: List[str],
+ device: torch.device,
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ ):
+ """Encode prompt using standard text encoder and tokenizer with batch processing."""
+ batch_size = len(prompt)
+
+ if do_classifier_free_guidance:
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+
+ prompts_to_encode = negative_prompt + prompt
+ else:
+ prompts_to_encode = prompt
+
+ input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
+
+ with torch.no_grad():
+ embeddings = self.text_encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ output_hidden_states=True,
+ )["last_hidden_state"]
+
+ if do_classifier_free_guidance:
+ uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
+ uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
+ else:
+ text_embeddings = embeddings
+ cross_attn_mask = attention_mask
+ uncond_text_embeddings = None
+ uncond_cross_attn_mask = None
+
+ return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
+
+ def check_inputs(
+ self,
+ prompt: Union[str, List[str]],
+ height: int,
+ width: int,
+ guidance_scale: float,
+ callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ """Check that all inputs are in correct format."""
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+
+ if prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
+ raise ValueError(
+ "When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
+ "`negative_prompt_embeds` must also be provided for classifier-free guidance."
+ )
+
+ spatial_compression = self.vae_scale_factor
+ if height % spatial_compression != 0 or width % spatial_compression != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
+ )
+
+ if guidance_scale < 1.0:
+ raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
+
+ if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ timesteps: List[int] = None,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ use_resolution_binning: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ negative_prompt (`str`, *optional*, defaults to `""`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
+ empty string.
+ prompt_attention_mask (`torch.BoolTensor`, *optional*):
+ Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
+ from `prompt` input argument.
+ negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
+ Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
+ attention mask will be generated from an empty string.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
+ use_resolution_binning (`bool`, *optional*, defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
+ to the requested resolution. Useful for generating non-square images at optimal resolutions.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
+ `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
+ in the `._callback_tensor_inputs` attribute.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
+ True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
+ """
+
+ # 0. Set height and width
+ default_resolution = self.get_default_resolution()
+ height = height or default_resolution
+ width = width or default_resolution
+
+ if use_resolution_binning:
+ if self.image_processor is None:
+ raise ValueError(
+ "Resolution binning requires a VAE with image_processor, but VAE is not available. "
+ "Set use_resolution_binning=False or provide a VAE."
+ )
+ if self.default_sample_size <= 256:
+ aspect_ratio_bin = ASPECT_RATIO_256_BIN
+ else:
+ aspect_ratio_bin = ASPECT_RATIO_512_BIN
+
+ # Store original dimensions
+ orig_height, orig_width = height, width
+ # Map to closest resolution in the bin
+ height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ guidance_scale,
+ callback_on_step_end_tensor_inputs,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if self.vae is None and output_type not in ["latent", "pt"]:
+ raise ValueError(
+ f"VAE is required for output_type='{output_type}' but it is not available. "
+ "Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
+ )
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Use execution device (handles offloading scenarios including group offloading)
+ device = self._execution_device
+
+ self._guidance_scale = guidance_scale
+
+ # 2. Encode input prompt
+ text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
+ prompt,
+ device,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ )
+ # Expose standard names for callbacks parity
+ prompt_embeds = text_embeddings
+ negative_prompt_embeds = uncond_text_embeddings
+
+ # 3. Prepare timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ self.num_timesteps = len(timesteps)
+
+ # 4. Prepare latent variables
+ if self.vae is not None:
+ num_channels_latents = self.vae.config.latent_channels
+ else:
+ # When vae is None, get latent channels from transformer
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ text_embeddings.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare extra step kwargs
+ extra_step_kwargs = {}
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_eta:
+ extra_step_kwargs["eta"] = 0.0
+
+ # 6. Prepare cross-attention embeddings and masks
+ if self.do_classifier_free_guidance:
+ ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
+ ca_mask = None
+ if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
+ ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
+ else:
+ ca_embed = text_embeddings
+ ca_mask = cross_attn_mask
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # Duplicate latents if using classifier-free guidance
+ if self.do_classifier_free_guidance:
+ latents_in = torch.cat([latents, latents], dim=0)
+ # Normalize timestep for the transformer
+ t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
+ else:
+ latents_in = latents
+ # Normalize timestep for the transformer
+ t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
+
+ # Forward through transformer
+ noise_pred = self.transformer(
+ hidden_states=latents_in,
+ timestep=t_cont,
+ encoder_hidden_states=ca_embed,
+ attention_mask=ca_mask,
+ return_dict=False,
+ )[0]
+
+ # Apply CFG
+ if self.do_classifier_free_guidance:
+ noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
+ noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
+
+ # Compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_on_step_end(self, i, t, callback_kwargs)
+
+ # Call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ # 8. Post-processing
+ if output_type == "latent" or (output_type == "pt" and self.vae is None):
+ image = latents
+ else:
+ # Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
+ scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
+ shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
+ latents = (latents / scaling_factor) + shift_factor
+ # Decode using VAE (AutoencoderKL or AutoencoderDC)
+ image = self.vae.decode(latents, return_dict=False)[0]
+ # Resize back to original resolution if using binning
+ if use_resolution_binning:
+ image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
+
+ # Use standard image processor for post-processing
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return PRXPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/qwenimage/__init__.py b/src/diffusers/pipelines/qwenimage/__init__.py
index 36d92917fd..2400632ba2 100644
--- a/src/diffusers/pipelines/qwenimage/__init__.py
+++ b/src/diffusers/pipelines/qwenimage/__init__.py
@@ -28,6 +28,7 @@ else:
_import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
+ _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
@@ -43,6 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_qwenimage_controlnet_inpaint import QwenImageControlNetInpaintPipeline
from .pipeline_qwenimage_edit import QwenImageEditPipeline
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
+ from .pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
else:
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
index 88d1ce4a46..ed37b238c8 100644
--- a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py
@@ -208,7 +208,6 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
# QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
- self.vl_processor = processor
self.tokenizer_max_length = 1024
self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
diff --git a/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
new file mode 100644
index 0000000000..ec203edf16
--- /dev/null
+++ b/src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py
@@ -0,0 +1,883 @@
+# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import math
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import QwenImageLoraLoaderMixin
+from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import QwenImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from PIL import Image
+ >>> from diffusers import QwenImageEditPlusPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = (
+ ... "Make Pikachu hold a sign that says 'Qwen Edit is awesome', yarn art style, detailed, vibrant colors"
+ ... )
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
+ >>> # Refer to the pipeline documentation for more details.
+ >>> image = pipe(image, prompt, num_inference_steps=50).images[0]
+ >>> image.save("qwenimage_edit_plus.png")
+ ```
+"""
+
+CONDITION_IMAGE_SIZE = 384 * 384
+VAE_IMAGE_SIZE = 1024 * 1024
+
+
+# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def calculate_dimensions(target_area, ratio):
+ width = math.sqrt(target_area * ratio)
+ height = width / ratio
+
+ width = round(width / 32) * 32
+ height = round(height / 32) * 32
+
+ return width, height
+
+
+class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+ r"""
+ The Qwen-Image-Edit pipeline for image editing.
+
+ Args:
+ transformer ([`QwenImageTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`Qwen2.5-VL-7B-Instruct`]):
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct), specifically the
+ [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
+ tokenizer (`QwenTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKLQwenImage,
+ text_encoder: Qwen2_5_VLForConditionalGeneration,
+ tokenizer: Qwen2Tokenizer,
+ processor: Qwen2VLProcessor,
+ transformer: QwenImageTransformer2DModel,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ processor=processor,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.latent_channels = self.vae.config.z_dim if getattr(self, "vae", None) else 16
+ # QwenImage latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = 1024
+
+ self.prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+ self.prompt_template_encode_start_idx = 64
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
+ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+ bool_mask = mask.bool()
+ valid_lengths = bool_mask.sum(dim=1)
+ selected = hidden_states[bool_mask]
+ split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
+
+ return split_result
+
+ def _get_qwen_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+ if isinstance(image, list):
+ base_img_prompt = ""
+ for i, img in enumerate(image):
+ base_img_prompt += img_prompt_template.format(i + 1)
+ elif image is not None:
+ base_img_prompt = img_prompt_template.format(1)
+ else:
+ base_img_prompt = ""
+
+ template = self.prompt_template_encode
+
+ drop_idx = self.prompt_template_encode_start_idx
+ txt = [template.format(base_img_prompt + e) for e in prompt]
+
+ model_inputs = self.processor(
+ text=txt,
+ images=image,
+ padding=True,
+ return_tensors="pt",
+ ).to(device)
+
+ outputs = self.text_encoder(
+ input_ids=model_inputs.input_ids,
+ attention_mask=model_inputs.attention_mask,
+ pixel_values=model_inputs.pixel_values,
+ image_grid_thw=model_inputs.image_grid_thw,
+ output_hidden_states=True,
+ )
+
+ hidden_states = outputs.hidden_states[-1]
+ split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+ split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+ attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+ max_seq_len = max([e.size(0) for e in split_hidden_states])
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+ )
+ encoder_attention_mask = torch.stack(
+ [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+ )
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ return prompt_embeds, encoder_attention_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ image: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 1024,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ image (`torch.Tensor`, *optional*):
+ image to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+ return prompt_embeds, prompt_embeds_mask
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_embeds_mask=None,
+ negative_prompt_embeds_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_embeds_mask is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `prompt_embeds_mask` also have to be passed. Make sure to generate `prompt_embeds_mask` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_prompt_embeds_mask` also have to be passed. Make sure to generate `negative_prompt_embeds_mask` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 1024:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
+
+ return latents
+
+ # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ latents_std = (
+ torch.tensor(self.vae.config.latents_std)
+ .view(1, self.latent_channels, 1, 1, 1)
+ .to(image_latents.device, image_latents.dtype)
+ )
+ image_latents = (image_latents - latents_mean) / latents_std
+
+ return image_latents
+
+ def prepare_latents(
+ self,
+ images,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, 1, num_channels_latents, height, width)
+
+ image_latents = None
+ if images is not None:
+ if not isinstance(images, list):
+ images = [images]
+ all_image_latents = []
+ for image in images:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[3:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ all_image_latents.append(image_latents)
+ image_latents = torch.cat(all_image_latents, dim=1)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ true_cfg_scale: float = 4.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: Optional[float] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_embeds_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
+ Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
+ equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
+ enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
+ encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
+ lower image quality.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to None):
+ A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
+ where the guidance scale is applied during inference through noise prediction rescaling, guidance
+ distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
+ scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
+ that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
+ parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
+ ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
+ please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
+ enable classifier-free guidance computations).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] or `tuple`:
+ [`~pipelines.qwenimage.QwenImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is a list with the generated images.
+ """
+ image_size = image[-1].size if isinstance(image, list) else image.size
+ calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+ height = height or calculated_height
+ width = width or calculated_width
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ negative_prompt_embeds_mask=negative_prompt_embeds_mask,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if not isinstance(image, list):
+ image = [image]
+ condition_image_sizes = []
+ condition_images = []
+ vae_image_sizes = []
+ vae_images = []
+ for img in image:
+ image_width, image_height = img.size
+ condition_width, condition_height = calculate_dimensions(
+ CONDITION_IMAGE_SIZE, image_width / image_height
+ )
+ vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+ condition_image_sizes.append((condition_width, condition_height))
+ vae_image_sizes.append((vae_width, vae_height))
+ condition_images.append(self.image_processor.resize(img, condition_height, condition_width))
+ vae_images.append(self.image_processor.preprocess(img, vae_height, vae_width).unsqueeze(2))
+
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
+ )
+
+ if true_cfg_scale > 1 and not has_neg_prompt:
+ logger.warning(
+ f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
+ )
+ elif true_cfg_scale <= 1 and has_neg_prompt:
+ logger.warning(
+ " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
+ )
+
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=prompt,
+ prompt_embeds=prompt_embeds,
+ prompt_embeds_mask=prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+ if do_true_cfg:
+ negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+ image=condition_images,
+ prompt=negative_prompt,
+ prompt_embeds=negative_prompt_embeds,
+ prompt_embeds_mask=negative_prompt_embeds_mask,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents = self.prepare_latents(
+ vae_images,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ img_shapes = [
+ [
+ (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
+ *[
+ (1, vae_height // self.vae_scale_factor // 2, vae_width // self.vae_scale_factor // 2)
+ for vae_width, vae_height in vae_image_sizes
+ ],
+ ]
+ ] * batch_size
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds and guidance_scale is None:
+ raise ValueError("guidance_scale is required for guidance-distilled model.")
+ elif self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
+ logger.warning(
+ f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
+ )
+ guidance = None
+ elif not self.transformer.config.guidance_embeds and guidance_scale is None:
+ guidance = None
+
+ if self.attention_kwargs is None:
+ self._attention_kwargs = {}
+
+ txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
+ negative_txt_seq_lens = (
+ negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
+ )
+
+ # 6. Denoising loop
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=prompt_embeds_mask,
+ encoder_hidden_states=prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ encoder_hidden_states_mask=negative_prompt_embeds_mask,
+ encoder_hidden_states=negative_prompt_embeds,
+ img_shapes=img_shapes,
+ txt_seq_lens=negative_txt_seq_lens,
+ attention_kwargs=self.attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+ noise_pred = comb_pred * (cond_norm / noise_norm)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ image = self.vae.decode(latents, return_dict=False)[0][:, :, 0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return QwenImagePipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py
index 91684f35f1..d5571ab12f 100644
--- a/src/diffusers/pipelines/sana/__init__.py
+++ b/src/diffusers/pipelines/sana/__init__.py
@@ -26,6 +26,7 @@ else:
_import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"]
_import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
_import_structure["pipeline_sana_sprint_img2img"] = ["SanaSprintImg2ImgPipeline"]
+ _import_structure["pipeline_sana_video"] = ["SanaVideoPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_sana_controlnet import SanaControlNetPipeline
from .pipeline_sana_sprint import SanaSprintPipeline
from .pipeline_sana_sprint_img2img import SanaSprintImg2ImgPipeline
+ from .pipeline_sana_video import SanaVideoPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/sana/pipeline_output.py b/src/diffusers/pipelines/sana/pipeline_output.py
index f8ac129516..8021b77387 100644
--- a/src/diffusers/pipelines/sana/pipeline_output.py
+++ b/src/diffusers/pipelines/sana/pipeline_output.py
@@ -3,6 +3,7 @@ from typing import List, Union
import numpy as np
import PIL.Image
+import torch
from ...utils import BaseOutput
@@ -19,3 +20,18 @@ class SanaPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class SanaVideoPipelineOutput(BaseOutput):
+ r"""
+ Output class for Sana-Video pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py
index ac979305ca..2beff802c6 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana.py
@@ -1,4 +1,4 @@
-# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index 62b9788292..04f45f817e 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -1,4 +1,4 @@
-# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved.
+# Copyright 2025 SANA-Sprint Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_video.py b/src/diffusers/pipelines/sana/pipeline_sana_video.py
new file mode 100644
index 0000000000..5ec498faff
--- /dev/null
+++ b/src/diffusers/pipelines/sana/pipeline_sana_video.py
@@ -0,0 +1,1017 @@
+# Copyright 2025 SANA-Video Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import re
+import urllib.parse as ul
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SanaLoraLoaderMixin
+from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
+from ...schedulers import DPMSolverMultistepScheduler
+from ...utils import (
+ BACKENDS_MAPPING,
+ USE_PEFT_BACKEND,
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import get_device, is_torch_version, randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SanaVideoPipelineOutput
+
+
+ASPECT_RATIO_480_BIN = {
+ "0.5": [448.0, 896.0],
+ "0.57": [480.0, 832.0],
+ "0.68": [528.0, 768.0],
+ "0.78": [560.0, 720.0],
+ "1.0": [624.0, 624.0],
+ "1.13": [672.0, 592.0],
+ "1.29": [720.0, 560.0],
+ "1.46": [768.0, 528.0],
+ "1.67": [816.0, 496.0],
+ "1.75": [832.0, 480.0],
+ "2.0": [896.0, 448.0],
+}
+
+
+ASPECT_RATIO_720_BIN = {
+ "0.5": [672.0, 1344.0],
+ "0.57": [704.0, 1280.0],
+ "0.68": [800.0, 1152.0],
+ "0.78": [832.0, 1088.0],
+ "1.0": [960.0, 960.0],
+ "1.13": [1024.0, 896.0],
+ "1.29": [1088.0, 832.0],
+ "1.46": [1152.0, 800.0],
+ "1.67": [1248.0, 736.0],
+ "1.75": [1280.0, 704.0],
+ "2.0": [1344.0, 672.0],
+}
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_bs4_available():
+ from bs4 import BeautifulSoup
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import SanaVideoPipeline
+ >>> from diffusers.utils import export_to_video
+
+ >>> model_id = "Efficient-Large-Model/SANA-Video_2B_480p_diffusers"
+ >>> pipe = SanaVideoPipeline.from_pretrained(model_id)
+ >>> pipe.transformer.to(torch.bfloat16)
+ >>> pipe.text_encoder.to(torch.bfloat16)
+ >>> pipe.vae.to(torch.float32)
+ >>> pipe.to("cuda")
+ >>> model_score = 30
+
+ >>> prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest, golden light glimmers on his hair as sunlight filters through the leaves. He wears a light shirt, wind gently blowing his hair and collar, light dances across his face with his movements. The background is blurred, with dappled light and soft tree shadows in the distance. The camera focuses on his lifted gaze, clear and emotional."
+ >>> negative_prompt = "A chaotic sequence with misshapen, deformed limbs in heavy motion blur, sudden disappearance, jump cuts, jerky movements, rapid shot changes, frames out of sync, inconsistent character shapes, temporal artifacts, jitter, and ghosting effects, creating a disorienting visual experience."
+ >>> motion_prompt = f" motion score: {model_score}."
+ >>> prompt = prompt + motion_prompt
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... negative_prompt=negative_prompt,
+ ... height=480,
+ ... width=832,
+ ... frames=81,
+ ... guidance_scale=6,
+ ... num_inference_steps=50,
+ ... generator=torch.Generator(device="cuda").manual_seed(42),
+ ... ).frames[0]
+
+ >>> export_to_video(output, "sana-video-output.mp4", fps=16)
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
+ r"""
+ Pipeline for text-to-video generation using [Sana](https://huggingface.co/papers/2509.24695). This model inherits
+ from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods implemented for all
+ pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`GemmaTokenizer`] or [`GemmaTokenizerFast`]):
+ The tokenizer used to tokenize the prompt.
+ text_encoder ([`Gemma2PreTrainedModel`]):
+ Text encoder model to encode the input prompts.
+ vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer ([`SanaVideoTransformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`DPMSolverMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """
+
+ # fmt: off
+ bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}")
+ # fmt: on
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+ text_encoder: Gemma2PreTrainedModel,
+ vae: Union[AutoencoderDC, AutoencoderKLWan],
+ transformer: SanaVideoTransformer3DModel,
+ scheduler: DPMSolverMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
+ )
+
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
+
+ self.vae_scale_factor = self.vae_scale_factor_spatial
+
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: torch.device,
+ dtype: torch.dtype,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+ # prepare complex human instruction
+ if not complex_human_instruction:
+ max_length_all = max_sequence_length
+ else:
+ chi_prompt = "\n".join(complex_human_instruction)
+ prompt = [chi_prompt + p for p in prompt]
+ num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+ max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_length_all,
+ truncation=True,
+ add_special_tokens=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ prompt_attention_mask = text_inputs.attention_mask
+ prompt_attention_mask = prompt_attention_mask.to(device)
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+ prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+ return prompt_embeds, prompt_attention_mask
+
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: str = "",
+ num_videos_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ clean_caption: bool = False,
+ max_sequence_length: int = 300,
+ complex_human_instruction: Optional[List[str]] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the video generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ PixArt-Alpha, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ number of videos that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
+ clean_caption (`bool`, defaults to `False`):
+ If `True`, the function will preprocess and clean the provided caption before encoding.
+ max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+ complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+ If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+ the prompt.
+ """
+
+ if device is None:
+ device = self._execution_device
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ else:
+ dtype = None
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ # See Section 3.1. of the paper.
+ max_length = max_sequence_length
+ select_index = [0] + list(range(-max_length + 1, 0))
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ )
+
+ prompt_embeds = prompt_embeds[:, select_index]
+ prompt_attention_mask = prompt_attention_mask[:, select_index]
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ dtype=dtype,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=False,
+ )
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1)
+ else:
+ negative_prompt_embeds = None
+ negative_prompt_attention_mask = None
+
+ if self.text_encoder is not None:
+ if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs=None,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ ):
+ if height % 32 != 0 or width % 32 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
+ def _text_preprocessing(self, text, clean_caption=False):
+ if clean_caption and not is_bs4_available():
+ logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if clean_caption and not is_ftfy_available():
+ logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
+ logger.warning("Setting `clean_caption` to False...")
+ clean_caption = False
+
+ if not isinstance(text, (tuple, list)):
+ text = [text]
+
+ def process(text: str):
+ if clean_caption:
+ text = self._clean_caption(text)
+ text = self._clean_caption(text)
+ else:
+ text = text.lower().strip()
+ return text
+
+ return [process(t) for t in text]
+
+ # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
+ def _clean_caption(self, caption):
+ caption = str(caption)
+ caption = ul.unquote_plus(caption)
+ caption = caption.strip().lower()
+ caption = re.sub("", "person", caption)
+ # urls:
+ caption = re.sub(
+ r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ caption = re.sub(
+ r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
+ "",
+ caption,
+ ) # regex for urls
+ # html:
+ caption = BeautifulSoup(caption, features="html.parser").text
+
+ # @
+ caption = re.sub(r"@[\w\d]+\b", "", caption)
+
+ # 31C0—31EF CJK Strokes
+ # 31F0—31FF Katakana Phonetic Extensions
+ # 3200—32FF Enclosed CJK Letters and Months
+ # 3300—33FF CJK Compatibility
+ # 3400—4DBF CJK Unified Ideographs Extension A
+ # 4DC0—4DFF Yijing Hexagram Symbols
+ # 4E00—9FFF CJK Unified Ideographs
+ caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
+ caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
+ caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
+ caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
+ caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
+ caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
+ caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+ #######################################################
+
+ # все виды тире / all types of dash --> "-"
+ caption = re.sub(
+ r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
+ "-",
+ caption,
+ )
+
+ # кавычки к одному стандарту
+ caption = re.sub(r"[`´«»“”¨]", '"', caption)
+ caption = re.sub(r"[‘’]", "'", caption)
+
+ # "
+ caption = re.sub(r""?", "", caption)
+ # &
+ caption = re.sub(r"&", "", caption)
+
+ # ip addresses:
+ caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
+
+ # article ids:
+ caption = re.sub(r"\d:\d\d\s+$", "", caption)
+
+ # \n
+ caption = re.sub(r"\\n", " ", caption)
+
+ # "#123"
+ caption = re.sub(r"#\d{1,3}\b", "", caption)
+ # "#12345.."
+ caption = re.sub(r"#\d{5,}\b", "", caption)
+ # "123456.."
+ caption = re.sub(r"\b\d{6,}\b", "", caption)
+ # filenames:
+ caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
+
+ #
+ caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
+ caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
+
+ caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
+ caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
+
+ # this-is-my-cute-cat / this_is_my_cute_cat
+ regex2 = re.compile(r"(?:\-|\_)")
+ if len(re.findall(regex2, caption)) > 3:
+ caption = re.sub(regex2, " ", caption)
+
+ caption = ftfy.fix_text(caption)
+ caption = html.unescape(html.unescape(caption))
+
+ caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
+ caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
+ caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
+
+ caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
+ caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
+ caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
+ caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
+ caption = re.sub(r"\bpage\s+\d+\b", "", caption)
+
+ caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
+
+ caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
+
+ caption = re.sub(r"\b\s+\:\s+", r": ", caption)
+ caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
+ caption = re.sub(r"\s+", " ", caption)
+
+ caption.strip()
+
+ caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
+ caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
+ caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
+ caption = re.sub(r"^\.\S+$", "", caption)
+
+ return caption.strip()
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: str = "",
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ sigmas: List[float] = None,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ height: int = 480,
+ width: int = 832,
+ frames: int = 81,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ clean_caption: bool = False,
+ use_resolution_binning: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 300,
+ complex_human_instruction: List[str] = [
+ "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for video generation. Evaluate the level of detail in the user prompt:",
+ "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, motion, and temporal relationships to create vivid and dynamic scenes.",
+ "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
+ "Here are examples of how to transform or refine prompts:",
+ "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat slowly settling into a curled position, peacefully falling asleep on a warm sunny windowsill, with gentle sunlight filtering through surrounding pots of blooming red flowers.",
+ "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps gradually lighting up, a diverse crowd of people in colorful clothing walking past, and a double-decker bus smoothly passing by towering glass skyscrapers.",
+ "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
+ "User Prompt: ",
+ ],
+ ) -> Union[SanaVideoPipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate videos that are closely linked to
+ the text `prompt`, usually at the expense of lower video quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of videos to generate per prompt.
+ height (`int`, *optional*, defaults to 480):
+ The height in pixels of the generated video.
+ width (`int`, *optional*, defaults to 832):
+ The width in pixels of the generated video.
+ frames (`int`, *optional*, defaults to 81):
+ The number of frames in the generated video.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated video. Choose between mp4 or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SanaVideoPipelineOutput`] instead of a plain tuple.
+ attention_kwargs:
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ clean_caption (`bool`, *optional*, defaults to `True`):
+ Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
+ be installed. If the dependencies are not installed, the embeddings will be created from the raw
+ prompt.
+ use_resolution_binning (`bool` defaults to `True`):
+ If set to `True`, the requested height and width are first mapped to the closest resolutions using
+ `ASPECT_RATIO_480_BIN` or `ASPECT_RATIO_720_BIN`. After the produced latents are decoded into videos,
+ they are resized back to the requested resolution. Useful for generating non-square videos.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to `300`):
+ Maximum sequence length to use with the `prompt`.
+ complex_human_instruction (`List[str]`, *optional*):
+ Instructions for complex human attention:
+ https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaVideoPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated videos
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ if use_resolution_binning:
+ if self.transformer.config.sample_size == 30:
+ aspect_ratio_bin = ASPECT_RATIO_480_BIN
+ elif self.transformer.config.sample_size == 22:
+ aspect_ratio_bin = ASPECT_RATIO_720_BIN
+ else:
+ raise ValueError("Invalid sample size")
+ orig_height, orig_width = height, width
+ height, width = self.video_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ callback_on_step_end_tensor_inputs,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_attention_mask,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._interrupt = False
+
+ # 2. Default height and width to transformer
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ clean_caption=clean_caption,
+ max_sequence_length=max_sequence_length,
+ complex_human_instruction=complex_human_instruction,
+ lora_scale=lora_scale,
+ )
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+ prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+
+ # 4. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
+ )
+
+ # 5. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ latent_channels,
+ height,
+ width,
+ frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ transformer_dtype = self.transformer.dtype
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timestep = t.expand(latent_model_input.shape[0])
+
+ # predict noise model_output
+ noise_pred = self.transformer(
+ latent_model_input.to(dtype=transformer_dtype),
+ encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+ encoder_attention_mask=prompt_attention_mask,
+ timestep=timestep,
+ return_dict=False,
+ attention_kwargs=self.attention_kwargs,
+ )[0]
+ noise_pred = noise_pred.float()
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # learned sigma
+ if self.transformer.config.out_channels // 2 == latent_channels:
+ noise_pred = noise_pred.chunk(2, dim=1)[0]
+
+ # compute previous image: x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if output_type == "latent":
+ video = latents
+ else:
+ latents = latents.to(self.vae.dtype)
+ torch_accelerator_module = getattr(torch, get_device(), torch.cuda)
+ oom_error = (
+ torch.OutOfMemoryError
+ if is_torch_version(">=", "2.5.0")
+ else torch_accelerator_module.OutOfMemoryError
+ )
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ try:
+ video = self.vae.decode(latents, return_dict=False)[0]
+ except oom_error as e:
+ warnings.warn(
+ f"{e}. \n"
+ f"Try to use VAE tiling for large images. For example: \n"
+ f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+ )
+
+ if use_resolution_binning:
+ video = self.video_processor.resize_and_crop_tensor(video, orig_width, orig_height)
+
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SanaVideoPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
index a5f67bffe6..49b09e205c 100644
--- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
+++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py
@@ -48,8 +48,8 @@ class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`Q16SafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
- Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
- about a model's potential harms.
+ Please refer to the [model card](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) for
+ more details about a model's potential harms.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
"""
@@ -332,7 +332,7 @@ class SemanticStableDiffusionPipeline(DeprecatedPipelineMixin, DiffusionPipeline
>>> from diffusers import SemanticStableDiffusionPipeline
>>> pipe = SemanticStableDiffusionPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
... )
>>> pipe = pipe.to("cuda")
diff --git a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
index 89d4d2dca5..07b382dfc4 100644
--- a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py
@@ -18,7 +18,6 @@ from typing import Optional
import torch
import torch.nn as nn
-import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
index aa39983c4e..a6a60ad94b 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
@@ -21,7 +21,7 @@ from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
@@ -55,7 +55,7 @@ EXAMPLE_DOC_STRING = """
"""
-class StableCascadeDecoderPipeline(DiffusionPipeline):
+class StableCascadeDecoderPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for generating images from the Stable Cascade model.
@@ -79,6 +79,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
width=int(24*10.67)=256 in order to match the training conditions.
"""
+ _last_supported_version = "0.35.2"
+
unet_name = "decoder"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->decoder->vqgan"
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
index b3dc23f2e5..838b93faaa 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py
@@ -20,7 +20,7 @@ from transformers import CLIPImageProcessor, CLIPTextModelWithProjection, CLIPTo
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import is_torch_version, replace_example_docstring
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
@@ -42,7 +42,7 @@ TEXT2IMAGE_EXAMPLE_DOC_STRING = """
"""
-class StableCascadeCombinedPipeline(DiffusionPipeline):
+class StableCascadeCombinedPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Stable Cascade.
@@ -74,6 +74,8 @@ class StableCascadeCombinedPipeline(DiffusionPipeline):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
"""
+ _last_supported_version = "0.35.2"
+
_load_connected_pipes = True
_optional_components = ["prior_feature_extractor", "prior_image_encoder"]
diff --git a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
index 9e63b3489c..29ad8b5429 100644
--- a/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
+++ b/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
@@ -25,7 +25,7 @@ from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
-from ..pipeline_utils import DiffusionPipeline
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline
if is_torch_xla_available():
@@ -77,7 +77,7 @@ class StableCascadePriorPipelineOutput(BaseOutput):
negative_prompt_embeds_pooled: Union[torch.Tensor, np.ndarray]
-class StableCascadePriorPipeline(DiffusionPipeline):
+class StableCascadePriorPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for generating image prior for Stable Cascade.
@@ -103,6 +103,8 @@ class StableCascadePriorPipeline(DiffusionPipeline):
Default resolution for multiple images generated.
"""
+ _last_supported_version = "0.35.2"
+
unet_name = "prior"
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "image_encoder->text_encoder->prior"
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index 1afa7698da..6befe77aa4 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -349,12 +349,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
index 78e3ba239c..81656beba7 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
@@ -389,12 +389,8 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
Examples:
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
index b7e17ba681..5938fe232a 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py
@@ -103,11 +103,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
r"""
Flax-based pipeline for text-guided image inpainting using Stable Diffusion.
-
-
- 🧪 This is an experimental feature!
-
-
+ > [!WARNING] > 🧪 This is an experimental feature!
This model inherits from [`FlaxDiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -435,12 +431,8 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
-
-
- This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
- future release.
-
-
+ > [!WARNING] > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
+ removed in a > future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 1618f89a49..660d9801df 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -248,7 +248,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -336,7 +336,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 7e97909f42..9b11bc8781 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -272,7 +272,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -361,7 +361,7 @@ class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index bed596e57c..b947cbff09 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -278,7 +278,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
return torch.zeros(
(
batch_size * num_images_per_prompt,
- self.tokenizer_max_length,
+ max_sequence_length,
self.transformer.config.joint_attention_dim,
),
device=device,
@@ -367,7 +367,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
- pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
return prompt_embeds, pooled_prompt_embeds
@@ -1247,7 +1247,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
# match the inpainting pipeline and will be updated with input + mask inpainting model later
if num_channels_transformer == 33:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if (
diff --git a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
index 87bd9f4444..65c25ffbe4 100644
--- a/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
+++ b/src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py
@@ -249,11 +249,7 @@ class StableDiffusionDiffEditPipeline(
StableDiffusionLoraLoaderMixin,
):
r"""
-
-
- This is an experimental feature!
-
-
+ > [!WARNING] > This is an experimental feature!
Pipeline for text-guided image inpainting using Stable Diffusion and DiffEdit.
diff --git a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
index df2564a89b..feebd6adf8 100755
--- a/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py
@@ -81,11 +81,7 @@ class StableDiffusionKDiffusionPipeline(
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
-
-
- This is an experimental pipeline and is likely to change in the future.
-
-
+ > [!WARNING] > This is an experimental pipeline and is likely to change in the future.
Args:
vae ([`AutoencoderKL`]):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index 18f8536a75..88cc7515b0 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -1501,7 +1501,7 @@ class StableDiffusionXLInpaintPipeline(
# 8. Check that sizes of mask, masked image and latents match
if num_channels_unet == 9:
- # default case for runwayml/stable-diffusion-inpainting
+ # default case for stable-diffusion-v1-5/stable-diffusion-inpainting
num_channels_mask = mask.shape[1]
num_channels_masked_image = masked_image_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_vace.py b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
index eab1aacfc5..63e557a98f 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_vace.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -152,34 +152,36 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
- transformer ([`WanVACETransformer3DModel`]):
- Conditional Transformer to denoise the input latents.
- transformer_2 ([`WanVACETransformer3DModel`], *optional*):
- Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
- `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
- `transformer` is used.
- scheduler ([`UniPCMultistepScheduler`]):
- A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ transformer ([`WanVACETransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the high-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+ `transformer` or `transformer_2` must be provided.
+ transformer_2 ([`WanVACETransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. At least one of
+ `transformer` or `transformer_2` must be provided.
boundary_ratio (`float`, *optional*, defaults to `None`):
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
- boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
+ boundary_timestep. If `None`, only the available transformer is used for the entire denoising process.
"""
- model_cpu_offload_seq = "text_encoder->transformer->vae"
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
- _optional_components = ["transformer_2"]
+ _optional_components = ["transformer", "transformer_2"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- transformer: WanVACETransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer: WanVACETransformer3DModel = None,
transformer_2: WanVACETransformer3DModel = None,
boundary_ratio: Optional[float] = None,
):
@@ -336,7 +338,15 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
reference_images=None,
guidance_scale_2=None,
):
- base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+ if self.transformer is not None:
+ base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+ elif self.transformer_2 is not None:
+ base = self.vae_scale_factor_spatial * self.transformer_2.config.patch_size[1]
+ else:
+ raise ValueError(
+ "`transformer` or `transformer_2` component must be set in order to run inference with this pipeline"
+ )
+
if height % base != 0 or width % base != 0:
raise ValueError(f"`height` and `width` have to be divisible by {base} but are {height} and {width}.")
@@ -414,7 +424,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
device: Optional[torch.device] = None,
):
if video is not None:
- base = self.vae_scale_factor_spatial * self.transformer.config.patch_size[1]
+ base = self.vae_scale_factor_spatial * (
+ self.transformer.config.patch_size[1]
+ if self.transformer is not None
+ else self.transformer_2.config.patch_size[1]
+ )
video_height, video_width = self.video_processor.get_default_height_width(video[0])
if video_height * video_width > height * width:
@@ -589,7 +603,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
"Generating with more than one video is not yet supported. This may be supported in the future."
)
- transformer_patch_size = self.transformer.config.patch_size[1]
+ transformer_patch_size = (
+ self.transformer.config.patch_size[1]
+ if self.transformer is not None
+ else self.transformer_2.config.patch_size[1]
+ )
mask_list = []
for mask_, reference_images_batch in zip(mask, reference_images):
@@ -795,7 +813,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# Simplification of implementation for now
- if not isinstance(prompt, str):
+ if prompt is not None and not isinstance(prompt, str):
raise ValueError("Passing a list of prompts is not yet supported. This may be supported in the future.")
if num_videos_per_prompt != 1:
raise ValueError(
@@ -844,20 +862,25 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
vae_dtype = self.vae.dtype
- transformer_dtype = self.transformer.dtype
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+ vace_layers = (
+ self.transformer.config.vace_layers
+ if self.transformer is not None
+ else self.transformer_2.config.vace_layers
+ )
if isinstance(conditioning_scale, (int, float)):
- conditioning_scale = [conditioning_scale] * len(self.transformer.config.vace_layers)
+ conditioning_scale = [conditioning_scale] * len(vace_layers)
if isinstance(conditioning_scale, list):
- if len(conditioning_scale) != len(self.transformer.config.vace_layers):
+ if len(conditioning_scale) != len(vace_layers):
raise ValueError(
- f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+ f"Length of `conditioning_scale` {len(conditioning_scale)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = torch.tensor(conditioning_scale)
if isinstance(conditioning_scale, torch.Tensor):
- if conditioning_scale.size(0) != len(self.transformer.config.vace_layers):
+ if conditioning_scale.size(0) != len(vace_layers):
raise ValueError(
- f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(self.transformer.config.vace_layers)}."
+ f"Length of `conditioning_scale` {conditioning_scale.size(0)} does not match number of layers {len(vace_layers)}."
)
conditioning_scale = conditioning_scale.to(device=device, dtype=transformer_dtype)
@@ -900,7 +923,11 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
conditioning_latents = torch.cat([conditioning_latents, mask], dim=1)
conditioning_latents = conditioning_latents.to(transformer_dtype)
- num_channels_latents = self.transformer.config.in_channels
+ num_channels_latents = (
+ self.transformer.config.in_channels
+ if self.transformer is not None
+ else self.transformer_2.config.in_channels
+ )
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -968,7 +995,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
index 1a2d2e9c22..a976126da7 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -49,7 +49,7 @@ EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
- >>> from diffusers.utils import export_to_video
+ >>> from diffusers.utils import export_to_video, load_video
>>> from diffusers import AutoencoderKLWan, WanVideoToVideoPipeline
>>> from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py
index bf85795651..5dd8f56717 100644
--- a/src/diffusers/quantizers/quantization_config.py
+++ b/src/diffusers/quantizers/quantization_config.py
@@ -21,19 +21,20 @@ https://github.com/huggingface/transformers/blob/52cb4034ada381fe1ffe8d428a1076e
"""
import copy
+import dataclasses
import importlib.metadata
import inspect
import json
import os
import warnings
-from dataclasses import dataclass
+from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
from packaging import version
-from ..utils import is_torch_available, is_torchao_available, logging
+from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -443,7 +444,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Args:
- quant_type (`str`):
+ quant_type (Union[`str`, AOBaseConfig]):
The type of quantization we want to use, currently supporting:
- **Integer quantization:**
- Full function names: `int4_weight_only`, `int8_dynamic_activation_int4_weight`,
@@ -465,6 +466,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- **Unsigned Integer quantization:**
- Full function names: `uintx_weight_only`
- Shorthands: `uint1wo`, `uint2wo`, `uint3wo`, `uint4wo`, `uint5wo`, `uint6wo`, `uint7wo`
+ - An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`List[str]`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
modules left in their original precision.
@@ -478,6 +480,12 @@ class TorchAoConfig(QuantizationConfigMixin):
```python
from diffusers import FluxTransformer2DModel, TorchAoConfig
+ # AOBaseConfig-based configuration
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+
+ # String-based config
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
"black-forest-labs/Flux.1-Dev",
@@ -490,7 +498,7 @@ class TorchAoConfig(QuantizationConfigMixin):
def __init__(
self,
- quant_type: str,
+ quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List[str]] = None,
**kwargs,
) -> None:
@@ -504,34 +512,103 @@ class TorchAoConfig(QuantizationConfigMixin):
else:
self.quant_type_kwargs = kwargs
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
- is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
- if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ self.post_init()
+
+ def post_init(self):
+ if not isinstance(self.quant_type, str):
+ if is_torchao_version("<=", "0.9.0"):
raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
- f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
+ f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
)
- raise ValueError(
- f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
- f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
- )
+ from torchao.quantization.quant_api import AOBaseConfig
- method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
- signature = inspect.signature(method)
- all_kwargs = {
- param.name
- for param in signature.parameters.values()
- if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
- }
- unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+ if not isinstance(self.quant_type, AOBaseConfig):
+ raise TypeError(f"quant_type must be a AOBaseConfig instance, got {type(self.quant_type).__name__}")
- if len(unsupported_kwargs) > 0:
- raise ValueError(
- f'The quantization method "{quant_type}" does not support the following keyword arguments: '
- f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
- )
+ elif isinstance(self.quant_type, str):
+ TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
+
+ if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
+ is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
+ if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
+ )
+
+ raise ValueError(
+ f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
+ f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
+ )
+
+ method = TORCHAO_QUANT_TYPE_METHODS[self.quant_type]
+ signature = inspect.signature(method)
+ all_kwargs = {
+ param.name
+ for param in signature.parameters.values()
+ if param.kind in [inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD]
+ }
+ unsupported_kwargs = list(self.quant_type_kwargs.keys() - all_kwargs)
+
+ if len(unsupported_kwargs) > 0:
+ raise ValueError(
+ f'The quantization method "{self.quant_type}" does not support the following keyword arguments: '
+ f"{unsupported_kwargs}. The following keywords arguments are supported: {all_kwargs}."
+ )
+
+ def to_dict(self):
+ """Convert configuration to a dictionary."""
+ d = super().to_dict()
+
+ if isinstance(self.quant_type, str):
+ # Handle layout serialization if present
+ if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
+ if is_dataclass(d["quant_type_kwargs"]["layout"]):
+ d["quant_type_kwargs"]["layout"] = [
+ d["quant_type_kwargs"]["layout"].__class__.__name__,
+ dataclasses.asdict(d["quant_type_kwargs"]["layout"]),
+ ]
+ if isinstance(d["quant_type_kwargs"]["layout"], list):
+ assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layout kwargs"
+ assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string"
+ assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict"
+ else:
+ raise ValueError("layout must be a list")
+ else:
+ # Handle AOBaseConfig serialization
+ from torchao.core.config import config_to_dict
+
+ # For now we assume there is 1 config per Transformer, however in the future
+ # We may want to support a config per fqn.
+ d["quant_type"] = {"default": config_to_dict(self.quant_type)}
+
+ return d
+
+ @classmethod
+ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+ """Create configuration from a dictionary."""
+ if not is_torchao_version(">", "0.9.0"):
+ raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
+ config_dict = config_dict.copy()
+ quant_type = config_dict.pop("quant_type")
+
+ if isinstance(quant_type, str):
+ return cls(quant_type=quant_type, **config_dict)
+ # Check if we only have one key which is "default"
+ # In the future we may update this
+ assert len(quant_type) == 1 and "default" in quant_type, (
+ "Expected only one key 'default' in quant_type dictionary"
+ )
+ quant_type = quant_type["default"]
+
+ # Deserialize quant_type if needed
+ from torchao.core.config import config_from_dict
+
+ quant_type = config_from_dict(quant_type)
+
+ return cls(quant_type=quant_type, **config_dict)
@classmethod
def _get_torchao_quant_type_to_method(cls):
@@ -681,8 +758,38 @@ class TorchAoConfig(QuantizationConfigMixin):
raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")
def get_apply_tensor_subclass(self):
- TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
- return TORCHAO_QUANT_TYPE_METHODS[self.quant_type](**self.quant_type_kwargs)
+ """Create the appropriate quantization method based on configuration."""
+ if not isinstance(self.quant_type, str):
+ return self.quant_type
+ else:
+ methods = self._get_torchao_quant_type_to_method()
+ quant_type_kwargs = self.quant_type_kwargs.copy()
+ if (
+ not torch.cuda.is_available()
+ and is_torchao_available()
+ and self.quant_type == "int4_weight_only"
+ and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+ and quant_type_kwargs.get("layout", None) is None
+ ):
+ if torch.xpu.is_available():
+ if version.parse(importlib.metadata.version("torchao")) >= version.parse(
+ "0.11.0"
+ ) and version.parse(importlib.metadata.version("torch")) > version.parse("2.7.9"):
+ from torchao.dtypes import Int4XPULayout
+ from torchao.quantization.quant_primitives import ZeroPointDomain
+
+ quant_type_kwargs["layout"] = Int4XPULayout()
+ quant_type_kwargs["zero_point_domain"] = ZeroPointDomain.INT
+ else:
+ raise ValueError(
+ "TorchAoConfig requires torchao >= 0.11.0 and torch >= 2.8.0 for XPU support. Please upgrade the version or use run on CPU with the cpu version pytorch."
+ )
+ else:
+ from torchao.dtypes import Int4CPULayout
+
+ quant_type_kwargs["layout"] = Int4CPULayout()
+
+ return methods[self.quant_type](**quant_type_kwargs)
def __repr__(self):
r"""
diff --git a/src/diffusers/quantizers/torchao/torchao_quantizer.py b/src/diffusers/quantizers/torchao/torchao_quantizer.py
index 976bc8a1e0..2334c7af86 100644
--- a/src/diffusers/quantizers/torchao/torchao_quantizer.py
+++ b/src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -18,9 +18,10 @@ https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac17
"""
import importlib
+import re
import types
from fnmatch import fnmatch
-from typing import TYPE_CHECKING, Any, Dict, List, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version
@@ -107,6 +108,21 @@ if (
_update_torch_safe_globals()
+def fuzzy_match_size(config_name: str) -> Optional[str]:
+ """
+ Extract the size digit from strings like "4weight", "8weight". Returns the digit as an integer if found, otherwise
+ None.
+ """
+ config_name = config_name.lower()
+
+ str_match = re.search(r"(\d)weight", config_name)
+
+ if str_match:
+ return str_match.group(1)
+
+ return None
+
+
logger = logging.get_logger(__name__)
@@ -176,8 +192,7 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def update_torch_dtype(self, torch_dtype):
quant_type = self.quantization_config.quant_type
-
- if quant_type.startswith("int") or quant_type.startswith("uint"):
+ if isinstance(quant_type, str) and (quant_type.startswith("int") or quant_type.startswith("uint")):
if torch_dtype is not None and torch_dtype != torch.bfloat16:
logger.warning(
f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
@@ -197,24 +212,44 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
quant_type = self.quantization_config.quant_type
+ from accelerate.utils import CustomDtype
- if quant_type.startswith("int8") or quant_type.startswith("int4"):
- # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
- return torch.int8
- elif quant_type == "uintx_weight_only":
- return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
- elif quant_type.startswith("uint"):
- return {
- 1: torch.uint1,
- 2: torch.uint2,
- 3: torch.uint3,
- 4: torch.uint4,
- 5: torch.uint5,
- 6: torch.uint6,
- 7: torch.uint7,
- }[int(quant_type[4])]
- elif quant_type.startswith("float") or quant_type.startswith("fp"):
- return torch.bfloat16
+ if isinstance(quant_type, str):
+ if quant_type.startswith("int8"):
+ # Note that int4 weights are created by packing into torch.int8, but since there is no torch.int4, we use torch.int8
+ return torch.int8
+ elif quant_type.startswith("int4"):
+ return CustomDtype.INT4
+ elif quant_type == "uintx_weight_only":
+ return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
+ elif quant_type.startswith("uint"):
+ return {
+ 1: torch.uint1,
+ 2: torch.uint2,
+ 3: torch.uint3,
+ 4: torch.uint4,
+ 5: torch.uint5,
+ 6: torch.uint6,
+ 7: torch.uint7,
+ }[int(quant_type[4])]
+ elif quant_type.startswith("float") or quant_type.startswith("fp"):
+ return torch.bfloat16
+
+ elif is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ # Map the extracted digit to appropriate dtype
+ if size_digit == "4":
+ return CustomDtype.INT4
+ else:
+ # Default to int8
+ return torch.int8
if isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION):
return target_dtype
@@ -297,6 +332,21 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
# Original mapping for non-AOBaseConfig types
# For the uint types, this is a best guess. Once these types become more used
# we can look into their nuances.
+ if is_torchao_version(">", "0.9.0"):
+ from torchao.core.config import AOBaseConfig
+
+ quant_type = self.quantization_config.quant_type
+ # For autoquant case, it will be treated in the string implementation below in map_to_target_dtype
+ if isinstance(quant_type, AOBaseConfig):
+ # Extract size digit using fuzzy match on the class name
+ config_name = quant_type.__class__.__name__
+ size_digit = fuzzy_match_size(config_name)
+
+ if size_digit == "4":
+ return 8
+ else:
+ return 4
+
map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
quant_type = self.quantization_config.quant_type
for pattern, target_dtype in map_to_target_dtype.items():
diff --git a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
index 6b968e7081..9206ee80a6 100644
--- a/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
+++ b/src/diffusers/schedulers/deprecated/scheduling_karras_ve.py
@@ -53,13 +53,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
-
-
- For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid search
- values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in Table 5 of
- the paper.
-
-
+ > [!TIP] > For more details on the parameters, see [Appendix E](https://huggingface.co/papers/2206.00364). The grid
+ search > values used to find the optimal `{s_noise, s_churn, s_min, s_max}` for a specific model are described in
+ Table 5 of > the paper.
Args:
sigma_min (`float`, defaults to 0.02):
diff --git a/src/diffusers/schedulers/scheduling_consistency_models.py b/src/diffusers/schedulers/scheduling_consistency_models.py
index 0f50622588..5d81d5eb8a 100644
--- a/src/diffusers/schedulers/scheduling_consistency_models.py
+++ b/src/diffusers/schedulers/scheduling_consistency_models.py
@@ -268,11 +268,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
Gets the scalings used in the consistency model parameterization (from Appendix C of the
[paper](https://huggingface.co/papers/2303.01469)) to enforce boundary condition.
-
-
- `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
-
-
+ > [!TIP] > `epsilon` in the equations for `c_skip` and `c_out` is set to `sigma_min`.
Args:
sigma (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
index 66ed296da8..b9567f2c47 100644
--- a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -304,12 +304,8 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index d07ff8b200..8b523cd13f 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -630,12 +630,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
index 9ec9588511..f1a1ac3d82 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py
@@ -491,12 +491,8 @@ class DPMSolverMultistepInverseScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 8663210a62..1ae8249730 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -568,12 +568,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
index f1b38aaff5..e9ba695e1f 100644
--- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py
@@ -370,12 +370,8 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an
integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise
- prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both
+ noise > prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_sasolver.py b/src/diffusers/schedulers/scheduling_sasolver.py
index 2df7d560dd..2979ce193a 100644
--- a/src/diffusers/schedulers/scheduling_sasolver.py
+++ b/src/diffusers/schedulers/scheduling_sasolver.py
@@ -500,12 +500,8 @@ class SASolverScheduler(SchedulerMixin, ConfigMixin):
Noise_prediction is designed to discretize an integral of the noise prediction model, and data_prediction is
designed to discretize an integral of the data prediction model.
-
-
- The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction for both
- noise prediction and data prediction models.
-
-
+ > [!TIP] > The algorithm and model type are decoupled. You can use either data_prediction or noise_prediction
+ for both > noise prediction and data prediction models.
Args:
model_output (`torch.Tensor`):
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index f0e162ea6b..a355c7bb1a 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -138,15 +138,11 @@ class SchedulerMixin(PushToHubMixin):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
-
-
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
- auth login`. You can also activate the special
- ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ > [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
+ with `hf > auth login`. You can also activate the special >
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a >
firewalled environment.
-
-
"""
config, kwargs, commit_hash = cls.load_config(
pretrained_model_name_or_path=pretrained_model_name_or_path,
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index ffbe3b9020..0534e47d8a 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -120,19 +120,12 @@ class FlaxSchedulerMixin(PushToHubMixin):
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
-
+ > [!TIP] > It is required to be logged in (`hf auth login`) when you want to use private or [gated >
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
- It is required to be logged in (`hf auth login`) when you want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
-
-
-
- Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
- use this method in a firewalled environment.
-
-
+ > [!TIP] > Activate the special
+ ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to > use this method in a
+ firewalled environment.
"""
logger.warning(
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 63932221b2..cf77aaee82 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -38,7 +38,7 @@ from .constants import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
-from .deprecation_utils import deprecate
+from .deprecation_utils import _maybe_remap_transformers_class, deprecate
from .doc_utils import replace_example_docstring
from .dynamic_modules_utils import get_class_from_dynamic_module
from .export_utils import export_to_gif, export_to_obj, export_to_ply, export_to_video
@@ -64,6 +64,8 @@ from .import_utils import (
get_objects_from_module,
is_accelerate_available,
is_accelerate_version,
+ is_aiter_available,
+ is_aiter_version,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 8b4d76f3cb..42a53e1810 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -45,7 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
-DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
+DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
diff --git a/src/diffusers/utils/deprecation_utils.py b/src/diffusers/utils/deprecation_utils.py
index 4f001b3047..d76623541b 100644
--- a/src/diffusers/utils/deprecation_utils.py
+++ b/src/diffusers/utils/deprecation_utils.py
@@ -4,6 +4,54 @@ from typing import Any, Dict, Optional, Union
from packaging import version
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+# Mapping for deprecated Transformers classes to their replacements
+# This is used to handle models that reference deprecated class names in their configs
+# Reference: https://github.com/huggingface/transformers/issues/40822
+# Format: {
+# "DeprecatedClassName": {
+# "new_class": "NewClassName",
+# "transformers_version": (">=", "5.0.0"), # (operation, version) tuple
+# }
+# }
+_TRANSFORMERS_CLASS_REMAPPING = {
+ "CLIPFeatureExtractor": {
+ "new_class": "CLIPImageProcessor",
+ "transformers_version": (">", "4.57.0"),
+ },
+}
+
+
+def _maybe_remap_transformers_class(class_name: str) -> Optional[str]:
+ """
+ Check if a Transformers class should be remapped to a newer version.
+
+ Args:
+ class_name: The name of the class to check
+
+ Returns:
+ The new class name if remapping should occur, None otherwise
+ """
+ if class_name not in _TRANSFORMERS_CLASS_REMAPPING:
+ return None
+
+ from .import_utils import is_transformers_version
+
+ mapping = _TRANSFORMERS_CLASS_REMAPPING[class_name]
+ operation, required_version = mapping["transformers_version"]
+
+ # Only remap if the transformers version meets the requirement
+ if is_transformers_version(operation, required_version):
+ new_class = mapping["new_class"]
+ logger.warning(f"{class_name} appears to have been deprecated in transformers. Using {new_class} instead.")
+ return mapping["new_class"]
+
+ return None
+
def deprecate(*args, take_from: Optional[Union[Dict, Any]] = None, standard_warn=True, stacklevel=2):
from .. import __version__
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index bbb9712496..22d2d8c0a5 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -17,6 +17,21 @@ class AdaptiveProjectedGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class AdaptiveProjectedMixGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoGuidance(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +47,21 @@ class AutoGuidance(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class BaseGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ClassifierFreeGuidance(metaclass=DummyObject):
_backends = ["torch"]
@@ -378,6 +408,36 @@ class AutoencoderKLCosmos(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class AutoencoderKLHunyuanImage(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoencoderKLHunyuanImageRefiner(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -528,6 +588,21 @@ class AutoModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class BriaFiboTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class BriaTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -648,6 +723,21 @@ class ConsistencyDecoderVAE(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class ContextParallelConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -843,6 +933,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class HunyuanImageTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HunyuanVideoFramepackTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -903,6 +1008,21 @@ class Kandinsky3UNet(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class Kandinsky5Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class LatteTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1053,6 +1173,21 @@ class OmniGenTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class ParallelConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PixArtTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1083,6 +1218,21 @@ class PriorTransformer(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class PRXTransformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class QwenImageControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1158,6 +1308,21 @@ class SanaTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class SanaVideoTransformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SD3ControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 00792fa55a..e8209403de 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -17,6 +17,36 @@ class FluxAutoBlocks(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class FluxKontextAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxKontextModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -77,6 +107,36 @@ class QwenImageEditModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageEditPlusAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class QwenImageEditPlusModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImageModularPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -422,6 +482,21 @@ class AuraFlowPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class BriaFiboPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class BriaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -977,6 +1052,36 @@ class HunyuanDiTPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class HunyuanImagePipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class HunyuanImageRefinerPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class HunyuanSkyreelsImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1187,6 +1292,21 @@ class Kandinsky3Pipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class Kandinsky5T2VPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class KandinskyCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1592,6 +1712,21 @@ class LTXPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class LucyEditPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class Lumina2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1817,6 +1952,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class PRXPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1877,6 +2027,21 @@ class QwenImageEditPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class QwenImageEditPlusPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class QwenImageImg2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -2012,6 +2177,21 @@ class SanaSprintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class SanaVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class SemanticStableDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 674eb65773..1c0734cf35 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -151,8 +151,8 @@ def check_imports(filename):
missing_packages.append(imp)
if len(missing_packages) > 0:
- raise ImportError(
- "This modeling file requires the following packages that were not found in your environment: "
+ logger.warning(
+ "This modeling file might require the following packages that were not found in your environment: "
f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
)
@@ -247,12 +247,14 @@ def find_pipeline_class(loaded_module):
def get_cached_module_file(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
proxies: Optional[Dict[str, str]] = None,
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
+ local_dir: Optional[str] = None,
):
"""
Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
@@ -289,12 +291,8 @@ def get_cached_module_file(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`str`: The path to the module inside the cache.
@@ -335,6 +333,7 @@ def get_cached_module_file(
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
submodule = "git"
module_file = pretrained_model_name_or_path + ".py"
@@ -353,10 +352,13 @@ def get_cached_module_file(
resolved_module_file = hf_hub_download(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
+ local_dir=local_dir,
+ revision=revision,
token=token,
)
submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
@@ -410,12 +412,14 @@ def get_cached_module_file(
get_cached_module_file(
pretrained_model_name_or_path,
f"{module_needed}.py",
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
return os.path.join(full_submodule, module_file)
@@ -424,6 +428,7 @@ def get_cached_module_file(
def get_class_from_dynamic_module(
pretrained_model_name_or_path: Union[str, os.PathLike],
module_file: str,
+ subfolder: Optional[str] = None,
class_name: Optional[str] = None,
cache_dir: Optional[Union[str, os.PathLike]] = None,
force_download: bool = False,
@@ -431,17 +436,13 @@ def get_class_from_dynamic_module(
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
local_files_only: bool = False,
- **kwargs,
+ local_dir: Optional[str] = None,
):
"""
Extracts a class from a module file, present in the local folder or repository of a model.
-
-
- Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
- therefore only be called on trusted repos.
-
-
+ > [!WARNING] > Calling this function will execute the code in the module file found locally or downloaded from the
+ Hub. It should > therefore only be called on trusted repos.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
@@ -476,12 +477,8 @@ def get_class_from_dynamic_module(
local_files_only (`bool`, *optional*, defaults to `False`):
If `True`, will only try to load the tokenizer configuration from local files.
-
-
- You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
- models](https://huggingface.co/docs/hub/models-gated#gated-models).
-
-
+ > [!TIP] > You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or
+ [gated > models](https://huggingface.co/docs/hub/models-gated#gated-models).
Returns:
`type`: The class, dynamically imported from the module.
@@ -497,11 +494,13 @@ def get_class_from_dynamic_module(
final_module = get_cached_module_file(
pretrained_model_name_or_path,
module_file,
+ subfolder=subfolder,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
local_files_only=local_files_only,
+ local_dir=local_dir,
)
return get_class_in_module(class_name, final_module)
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index fcdf49156a..d0b05c7d95 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -38,13 +38,13 @@ from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY, HF_HUB_OFFLINE
from huggingface_hub.file_download import REGEX_COMMIT_HASH
from huggingface_hub.utils import (
EntryNotFoundError,
+ HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
is_jinja_available,
validate_hf_hub_args,
)
from packaging import version
-from requests import HTTPError
from .. import __version__
from .constants import (
@@ -113,7 +113,8 @@ def load_or_create_model_card(
Args:
repo_id_or_path (`str`):
- The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
+ The repo id (e.g., "stable-diffusion-v1-5/stable-diffusion-v1-5") or local path where to look for the model
+ card.
token (`str`, *optional*):
Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
details.
@@ -316,7 +317,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
) from e
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{e}"
) from e
@@ -432,7 +433,7 @@ def _get_checkpoint_shard_files(
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here. We have also dealt with EntryNotFoundError.
- except HTTPError as e:
+ except HfHubHTTPError as e:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load {pretrained_model_name_or_path}. You should try"
" again after checking your internet connection."
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 9399ccd2a7..adf8ed8b06 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -21,6 +21,7 @@ import operator as op
import os
import sys
from collections import OrderedDict, defaultdict
+from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
@@ -225,6 +226,7 @@ _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
@@ -405,6 +407,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available
+def is_aiter_available():
+ return _aiter_available
+
+
def is_kornia_available():
return _kornia_available
@@ -673,6 +679,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
+@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +693,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
+@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +709,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)
+@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +725,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)
+@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +741,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)
+@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +757,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)
+@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +773,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)
+@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
@@ -775,6 +788,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)
+@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +804,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
+@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
@@ -805,6 +820,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)
+@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +836,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)
+@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +852,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
+@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +868,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
+@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
@@ -865,6 +884,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)
+@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +900,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)
+@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
@@ -895,6 +916,22 @@ def is_flash_attn_version(operation: str, version: str):
return compare_versions(parse(_flash_attn_version), operation, version)
+@cache
+def is_aiter_version(operation: str, version: str):
+ """
+ Compares the current aiter version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _aiter_available:
+ return False
+ return compare_versions(parse(_aiter_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py
index 35691496a1..2b20f6120c 100644
--- a/src/diffusers/utils/outputs.py
+++ b/src/diffusers/utils/outputs.py
@@ -43,12 +43,8 @@ class BaseOutput(OrderedDict):
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
Python dictionary.
-
-
- You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
- first.
-
-
+ > [!WARNING] > You can't unpack a [`BaseOutput`] directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert
+ it to a tuple > first.
"""
def __init_subclass__(cls) -> None:
diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index 59b59b47d2..abeb30bca1 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -13,11 +13,12 @@
# limitations under the License.
import warnings
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
+import torch.nn.functional as F
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist
@@ -111,3 +112,65 @@ class VideoProcessor(VaeImageProcessor):
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")
return outputs
+
+ @staticmethod
+ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
+ r"""
+ Returns the binned height and width based on the aspect ratio.
+
+ Args:
+ height (`int`): The height of the image.
+ width (`int`): The width of the image.
+ ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).
+
+ Returns:
+ `Tuple[int, int]`: The closest binned height and width.
+ """
+ ar = float(height / width)
+ closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar))
+ default_hw = ratios[closest_ratio]
+ return int(default_hw[0]), int(default_hw[1])
+
+ @staticmethod
+ def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
+ r"""
+ Resizes and crops a tensor of videos to the specified dimensions.
+
+ Args:
+ samples (`torch.Tensor`):
+ A tensor of shape (N, C, T, H, W) where N is the batch size, C is the number of channels, T is the
+ number of frames, H is the height, and W is the width.
+ new_width (`int`): The desired width of the output videos.
+ new_height (`int`): The desired height of the output videos.
+
+ Returns:
+ `torch.Tensor`: A tensor containing the resized and cropped videos.
+ """
+ orig_height, orig_width = samples.shape[3], samples.shape[4]
+
+ # Check if resizing is needed
+ if orig_height != new_height or orig_width != new_width:
+ ratio = max(new_height / orig_height, new_width / orig_width)
+ resized_width = int(orig_width * ratio)
+ resized_height = int(orig_height * ratio)
+
+ # Reshape to (N*T, C, H, W) for interpolation
+ n, c, t, h, w = samples.shape
+ samples = samples.permute(0, 2, 1, 3, 4).reshape(n * t, c, h, w)
+
+ # Resize
+ samples = F.interpolate(
+ samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
+ )
+
+ # Center Crop
+ start_x = (resized_width - new_width) // 2
+ end_x = start_x + new_width
+ start_y = (resized_height - new_height) // 2
+ end_y = start_y + new_height
+ samples = samples[:, :, start_y:end_y, start_x:end_x]
+
+ # Reshape back to (N, C, T, H, W)
+ samples = samples.reshape(n, t, c, new_height, new_width).permute(0, 2, 1, 3, 4)
+
+ return samples
diff --git a/tests/lora/test_lora_layers_auraflow.py b/tests/lora/test_lora_layers_auraflow.py
index 67084dd6d0..91f63c4b56 100644
--- a/tests/lora/test_lora_layers_auraflow.py
+++ b/tests/lora/test_lora_layers_auraflow.py
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index 16147f35c7..fa57b4c9c2 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLCogVideoX,
- CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
- scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = {
"num_attention_heads": 4,
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
index 3b8a56c403..30eb8fbb63 100644
--- a/tests/lora/test_lora_layers_cogview4.py
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -50,7 +50,6 @@ class TokenizerWrapper:
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -124,30 +123,26 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(torch_device)
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(torch_device)
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 7d99bcad80..b840d7ac72 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
@@ -123,9 +122,6 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
@@ -171,8 +167,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -219,9 +214,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
# Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
@@ -282,9 +275,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler()
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 8,
@@ -331,6 +323,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ np.random.seed(0)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
@@ -907,6 +900,13 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
assert max_diff < 1e-3
+ def test_flux_kohya_embedders_conversion(self):
+ """Test that embedders load without throwing errors"""
+ self.pipeline.load_lora_weights("rockerBOO/flux-bpo-po-lora")
+ self.pipeline.unload_lora_weights()
+
+ assert True
+
def test_flux_xlabs(self):
self.pipeline.load_lora_weights("XLabs-AI/flux-lora-collection", weight_name="disney_lora.safetensors")
self.pipeline.fuse_lora()
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 62d045f836..cfd5d3146a 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -254,6 +253,7 @@ class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
expected_slices = Expectations(
{
("cuda", 7): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
+ ("xpu", 3): np.array([0.1013, 0.1924, 0.0078, 0.1021, 0.1929, 0.0078, 0.1023, 0.1919, 0.7402, 0.104, 0.4482, 0.7354, 0.0925, 0.4382, 0.7275, 0.0815]),
}
)
# fmt: on
diff --git a/tests/lora/test_lora_layers_ltx_video.py b/tests/lora/test_lora_layers_ltx_video.py
index a8ad30e448..6ab51a5e51 100644
--- a/tests/lora/test_lora_layers_ltx_video.py
+++ b/tests/lora/test_lora_layers_ltx_video.py
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_lumina2.py b/tests/lora/test_lora_layers_lumina2.py
index 0ebc831b11..0417b05b33 100644
--- a/tests/lora/test_lora_layers_lumina2.py
+++ b/tests/lora/test_lora_layers_lumina2.py
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -141,33 +140,30 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict=False,
)
def test_lora_fuse_nan(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
- out = pipe(**inputs)[0]
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
- self.assertTrue(np.isnan(out).all())
+ self.assertTrue(np.isnan(out).all())
diff --git a/tests/lora/test_lora_layers_mochi.py b/tests/lora/test_lora_layers_mochi.py
index 21cc5f11a3..7be81273db 100644
--- a/tests/lora/test_lora_layers_mochi.py
+++ b/tests/lora/test_lora_layers_mochi.py
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_qwenimage.py b/tests/lora/test_lora_layers_qwenimage.py
index 44ef9b0a37..51de2f8e20 100644
--- a/tests/lora/test_lora_layers_qwenimage.py
+++ b/tests/lora/test_lora_layers_qwenimage.py
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_sana.py b/tests/lora/test_lora_layers_sana.py
index a08908c610..a860b7b44f 100644
--- a/tests/lora/test_lora_layers_sana.py
+++ b/tests/lora/test_lora_layers_sana.py
@@ -20,7 +20,7 @@ from transformers import Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
-from ..testing_utils import floats_tensor, require_peft_backend
+from ..testing_utils import IS_GITHUB_ACTIONS, floats_tensor, require_peft_backend
sys.path.append(".")
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
- scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
- scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_kwargs = {"shift": 7.0}
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
@@ -137,3 +136,7 @@ class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@unittest.skip("Text encoder LoRA is not supported in SANA.")
def test_simple_inference_with_text_lora_save_load(self):
pass
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference_denoiser(self):
+ return super().test_layerwise_casting_inference_denoiser()
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 95f6f325e4..228460eaad 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 0ba80d2be1..5734509b41 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
index d8dde32dd8..ab1f57bfc9 100644
--- a/tests/lora/test_lora_layers_wanvace.py
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
- scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
@@ -165,13 +164,12 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
- scheduler_cls = self.scheduler_classes[0]
exclude_module_name = "vace_blocks.0.proj_out"
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
# only supported for `denoiser` now
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 72c1dddaa2..3d4344bb86 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -26,8 +26,6 @@ from parameterized import parameterized
from diffusers import (
AutoencoderKL,
- DDIMScheduler,
- LCMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import logging
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = None
scheduler_kwargs = None
- scheduler_classes = [DDIMScheduler, LCMScheduler]
has_two_text_encoders = False
has_three_text_encoders = False
@@ -129,13 +126,20 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ cached_non_lora_output = None
+
+ def get_base_pipe_output(self):
+ if self.cached_non_lora_output is None:
+ self.cached_non_lora_output = self._compute_baseline_output()
+ return self.cached_non_lora_output
+
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
- scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+ scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha
@@ -241,15 +245,16 @@ class PeftLoraLoaderMixinTests:
return noise, input_ids, pipeline_inputs
- # Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
- def get_dummy_tokens(self):
- max_seq_length = 77
+ def _compute_baseline_output(self):
+ components, _, _ = self.get_dummy_components(self.scheduler_cls)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0))
-
- prepared_inputs = {}
- prepared_inputs["input_ids"] = inputs
- return prepared_inputs
+ # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+ # explicitly.
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ return pipe(**inputs, generator=torch.manual_seed(0))[0]
def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {}
@@ -319,152 +324,132 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple inference and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- _, _, inputs = self.get_dummy_inputs()
- output_no_lora = pipe(**inputs)[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
+ assert output_no_lora.shape == self.output_shape
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
-
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
@require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
+ self.assertTrue(
+ "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+ "The LoRA params should be on 'meta' device.",
+ )
+
+ te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
+ set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+ self.assertTrue(
+ "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+ "No param should be on 'meta' device.",
+ )
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ self.assertTrue(
+ "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
+ )
+
+ denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
+ set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
+ self.assertTrue(
+ "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
+ )
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
- "meta" in {p.device.type for p in pipe.text_encoder.parameters()},
+ "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
- te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
- set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
+ te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
+ set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
- "meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
+ "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- self.assertTrue(
- "meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
- )
-
- denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
- set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
- self.assertTrue(
- "meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
- )
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
- self.assertTrue(
- "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
- "The LoRA params should be on 'meta' device.",
- )
-
- te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
- set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
- self.assertTrue(
- "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
- "No param should be on 'meta' device.",
- )
-
- _, _, inputs = self.get_dummy_inputs()
- output_lora = pipe(**inputs)[0]
- self.assertTrue(output_lora.shape == self.output_shape)
+ _, _, inputs = self.get_dummy_inputs()
+ output_lora = pipe(**inputs)[0]
+ self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
+ )
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+ # Now, check for `low_cpu_mem_usage.`
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- # Now, check for `low_cpu_mem_usage.`
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
-
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(
- images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
- ),
- "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
- )
+ images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
+ )
def test_simple_inference_with_text_lora_and_scale(self):
"""
@@ -472,411 +457,376 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ output_no_lora = self.get_base_pipe_output()
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+ output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
- output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
def test_simple_inference_with_text_lora_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe.fuse_lora()
- # Fusing should still keep the LoRA layers
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.fuse_lora()
+ # Fusing should still keep the LoRA layers
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
- ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
+
def test_simple_inference_with_partial_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
with different ranks and some adapters removed
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
- text_lora_config = LoraConfig(
- r=4,
- rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
- lora_alpha=4,
- target_modules=self.text_encoder_target_modules,
- init_lora_weights=False,
- use_dora=False,
- )
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, _, _ = self.get_dummy_components()
+ # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
+ text_lora_config = LoraConfig(
+ r=4,
+ rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
+ lora_alpha=4,
+ target_modules=self.text_encoder_target_modules,
+ init_lora_weights=False,
+ use_dora=False,
+ )
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- state_dict = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
- # supports missing layers (PR#8324).
- state_dict = {
- f"text_encoder.{module_name}": param
- for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
- if "text_model.encoder.layers.4" not in module_name
- }
+ state_dict = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+ # supports missing layers (PR#8324).
+ state_dict = {
+ f"text_encoder.{module_name}": param
+ for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+ if "text_model.encoder.layers.4" not in module_name
+ }
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- state_dict.update(
- {
- f"text_encoder_2.{module_name}": param
- for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
- if "text_model.encoder.layers.4" not in module_name
- }
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ state_dict.update(
+ {
+ f"text_encoder_2.{module_name}": param
+ for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
+ if "text_model.encoder.layers.4" not in module_name
+ }
+ )
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
- )
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- # Unload lora and load it back using the pipe.load_lora_weights machinery
- pipe.unload_lora_weights()
- pipe.load_lora_weights(state_dict)
+ # Unload lora and load it back using the pipe.load_lora_weights machinery
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(state_dict)
- output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
- "Removing adapters should change the output",
- )
+ output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
+ "Removing adapters should change the output",
+ )
def test_simple_inference_save_pretrained_with_text_lora(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ pipe.save_pretrained(tmpdirname)
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
+ pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
+ pipe_from_pretrained.to(torch_device)
- pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
- pipe_from_pretrained.to(torch_device)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
+ "Lora not correctly set in text encoder",
+ )
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
- check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
- "Lora not correctly set in text encoder",
+ check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
- "Lora not correctly set in text encoder 2",
- )
+ images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
- images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
- )
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
def test_simple_inference_with_text_denoiser_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
- )
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
- pipe.unload_lora_weights()
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
-
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results.",
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
+
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
+
+ images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results.",
+ )
+
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
"""
Tests a simple inference with lora attached on the text encoder + Unet + scale argument
and makes sure it works as expected
"""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ )
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
+ output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+ self.assertTrue(
+ np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
+ "Lora + 0 scale should lead to same result as no LoRA",
+ )
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
- not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
+ pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
+ "The scaling parameter has not been correctly restored!",
)
- attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
- self.assertTrue(
- not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
-
- attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
- output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
- "Lora + 0 scale should lead to same result as no LoRA",
- )
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
- "The scaling parameter has not been correctly restored!",
- )
-
def test_simple_inference_with_text_lora_denoiser_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
- # Fusing should still keep the LoRA layers
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ # Fusing should still keep the LoRA layers
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
-
- output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
- )
+ output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
+ )
def test_simple_inference_with_text_denoiser_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
- pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.unload_lora_weights()
- # unloading should remove the LoRA layers
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
- )
- self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
+ pipe.unload_lora_weights()
+ # unloading should remove the LoRA layers
+ self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
+ self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertFalse(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly unloaded in text encoder 2",
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertFalse(
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly unloaded in text encoder 2",
+ )
- output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
- "Fused lora should change the output",
- )
+ output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
+ "Fused lora should change the output",
+ )
def test_simple_inference_with_text_denoiser_lora_unfused(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
@@ -885,125 +835,120 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
- self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+ output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
- self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+ output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- # unloading should remove the LoRA layers
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+ # unloading should remove the LoRA layers
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
- # Fuse and unfuse should lead to the same results
- self.assertTrue(
- np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertTrue(
+ np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter outputs should be different.",
+ )
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter outputs should be different.",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
-
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1024,8 +969,7 @@ class PeftLoraLoaderMixinTests:
def test_multiple_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1054,131 +998,127 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora)
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
- weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
- pipe.set_adapters("adapter-1", weights_1)
- output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ pipe.set_adapters("adapter-1", weights_1)
+ output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- weights_2 = {"unet": {"up": 5}}
- pipe.set_adapters("adapter-1", weights_2)
- output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ weights_2 = {"unet": {"up": 5}}
+ pipe.set_adapters("adapter-1", weights_2)
+ output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
- "LoRA weights 1 and 2 should give different results",
- )
- self.assertFalse(
- np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
- "No adapter and LoRA weights 1 should give different results",
- )
- self.assertFalse(
- np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
- "No adapter and LoRA weights 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
+ "LoRA weights 1 and 2 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 1 should give different results",
+ )
+ self.assertFalse(
+ np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
+ "No adapter and LoRA weights 2 should give different results",
+ )
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set different weights for different blocks (i.e. block lora)
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
+ scales_2 = {"unet": {"down": 5, "mid": 5}}
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.set_adapters("adapter-1", scales_1)
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
- scales_2 = {"unet": {"down": 5, "mid": 5}}
+ pipe.set_adapters("adapter-2", scales_2)
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-1", scales_1)
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2", scales_2)
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
-
- # a mismatching number of adapter_names and adapter_weights should raise an error
- with self.assertRaises(ValueError):
- pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
+ # a mismatching number of adapter_names and adapter_weights should raise an error
+ with self.assertRaises(ValueError):
+ pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
@@ -1274,170 +1214,164 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ pipe.delete_adapters("adapter-1")
+ output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertTrue(
+ np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.delete_adapters("adapter-1")
- output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.delete_adapters("adapter-2")
+ output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
- pipe.delete_adapters("adapter-2")
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ pipe.delete_adapters(["adapter-1", "adapter-2"])
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters(["adapter-1", "adapter-2"])
- pipe.delete_adapters(["adapter-1", "adapter-2"])
-
- output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ pipe.set_adapters("adapter-1")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ pipe.set_adapters("adapter-2")
+ output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-1")
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
- pipe.set_adapters("adapter-2")
- output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # Fuse and unfuse should lead to the same results
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and 2 should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"])
- output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(
+ np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 1 and mixed adapters should give different results",
+ )
- # Fuse and unfuse should lead to the same results
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
- "Adapter 1 and 2 should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Adapter 2 and mixed adapters should give different results",
+ )
- self.assertFalse(
- np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 1 and mixed adapters should give different results",
- )
+ pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
+ output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Adapter 2 and mixed adapters should give different results",
- )
+ self.assertFalse(
+ np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+ "Weighted adapter and mixed adapter should give different results",
+ )
- pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
- output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ pipe.disable_lora()
+ output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(
- np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
- "Weighted adapter and mixed adapter should give different results",
- )
-
- pipe.disable_lora()
- output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
- "output with no lora and output with lora disabled should give same results",
- )
+ self.assertTrue(
+ np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
+ "output with no lora and output with lora disabled should give same results",
+ )
@skip_mps
@pytest.mark.xfail(
@@ -1446,165 +1380,157 @@ class PeftLoraLoaderMixinTests:
strict=False,
)
def test_lora_fuse_nan(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ # corrupt one LoRA weight with `inf` values
+ with torch.no_grad():
+ if self.unet_kwargs:
+ pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
+ "inf"
)
+ else:
+ named_modules = [name for name, _ in pipe.transformer.named_modules()]
+ possible_tower_names = [
+ "transformer_blocks",
+ "blocks",
+ "joint_transformer_blocks",
+ "single_transformer_blocks",
+ ]
+ filtered_tower_names = [
+ tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+ ]
+ if len(filtered_tower_names) == 0:
+ reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
+ raise ValueError(reason)
+ for tower_name in filtered_tower_names:
+ transformer_tower = getattr(pipe.transformer, tower_name)
+ has_attn1 = any("attn1" in name for name in named_modules)
+ if has_attn1:
+ transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
+ else:
+ transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ # with `safe_fusing=True` we should see an Error
+ with self.assertRaises(ValueError):
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
- # corrupt one LoRA weight with `inf` values
- with torch.no_grad():
- if self.unet_kwargs:
- pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
- "adapter-1"
- ].weight += float("inf")
- else:
- named_modules = [name for name, _ in pipe.transformer.named_modules()]
- possible_tower_names = [
- "transformer_blocks",
- "blocks",
- "joint_transformer_blocks",
- "single_transformer_blocks",
- ]
- filtered_tower_names = [
- tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
- ]
- if len(filtered_tower_names) == 0:
- reason = (
- f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
- )
- raise ValueError(reason)
- for tower_name in filtered_tower_names:
- transformer_tower = getattr(pipe.transformer, tower_name)
- has_attn1 = any("attn1" in name for name in named_modules)
- if has_attn1:
- transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
- else:
- transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
+ # without we should not see an error, but every image will be black
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+ out = pipe(**inputs)[0]
- # with `safe_fusing=True` we should see an Error
- with self.assertRaises(ValueError):
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
-
- # without we should not see an error, but every image will be black
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
- out = pipe(**inputs)[0]
-
- self.assertTrue(np.isnan(out).all())
+ self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-1"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-1"])
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
- adapter_names = pipe.get_active_adapters()
- self.assertListEqual(adapter_names, ["adapter-2"])
+ adapter_names = pipe.get_active_adapters()
+ self.assertListEqual(adapter_names, ["adapter-2"])
- pipe.set_adapters(["adapter-1", "adapter-2"])
- self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
def test_get_list_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
- # 1.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
+ # 1.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
- dicts_to_be_checked.update({"unet": ["adapter-1"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
- dicts_to_be_checked.update({"transformer": ["adapter-1"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
+ dicts_to_be_checked.update({"unet": ["adapter-1"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+ dicts_to_be_checked.update({"transformer": ["adapter-1"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- # 2.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ # 2.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- # 3.
- pipe.set_adapters(["adapter-1", "adapter-2"])
+ # 3.
+ pipe.set_adapters(["adapter-1", "adapter-2"])
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
- else:
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
+ if self.unet_kwargs is not None:
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
+ else:
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
- self.assertDictEqual(
- pipe.get_list_adapters(),
- dicts_to_be_checked,
- )
+ self.assertDictEqual(
+ pipe.get_list_adapters(),
+ dicts_to_be_checked,
+ )
- # 4.
- dicts_to_be_checked = {}
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
+ # 4.
+ dicts_to_be_checked = {}
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
- if self.unet_kwargs is not None:
- pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
- dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
- else:
- pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
- dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
+ if self.unet_kwargs is not None:
+ pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
+ dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
+ else:
+ pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
+ dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
- self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
+ self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
- @require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
):
@@ -1612,165 +1538,149 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+ pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ denoiser.add_adapter(denoiser_lora_config, "adapter-2")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ )
+ pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+
+ # set them to multi-adapter inference mode
+ pipe.set_adapters(["adapter-1", "adapter-2"])
+ outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe.set_adapters(["adapter-1"])
+ outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+ self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+ # Fusing should still keep the LoRA layers so output should remain the same
+ outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
+
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+ )
+
+ pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
+ self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+ # Fusing should still keep the LoRA layers
+ output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
+ "Fused lora should not change the output",
+ )
+ pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+ self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+ attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+ for lora_scale in [1.0, 0.8]:
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
- denoiser.add_adapter(denoiser_lora_config, "adapter-2")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+ check_if_lora_correctly_set(pipe.text_encoder_2),
+ "Lora not correctly set in text encoder 2",
)
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
-
- # set them to multi-adapter inference mode
- pipe.set_adapters(["adapter-1", "adapter-2"])
- outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1"])
- outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+ outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+ pipe.fuse_lora(
+ components=self.pipeline_class._lora_loadable_modules,
+ adapter_names=["adapter-1"],
+ lora_scale=lora_scale,
+ )
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
- # Fusing should still keep the LoRA layers so output should remain the same
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
-
- pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
- self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
-
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
- )
-
- pipe.fuse_lora(
- components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
- )
- self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
- # Fusing should still keep the LoRA layers
- output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(
- np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
- pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
- self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
- def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
- attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
-
- for lora_scale in [1.0, 0.8]:
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
- )
-
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config, "adapter-1")
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2),
- "Lora not correctly set in text encoder 2",
- )
-
- pipe.set_adapters(["adapter-1"])
- attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
- outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
-
- pipe.fuse_lora(
- components=self.pipeline_class._lora_loadable_modules,
- adapter_names=["adapter-1"],
- lora_scale=lora_scale,
- )
- self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
-
- outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
- self.assertTrue(
- np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
- "Fused lora should not change the output",
- )
- self.assertFalse(
- np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
- "LoRA should change the output",
- )
-
- @require_peft_version_greater(peft_version="0.9.0")
- def test_simple_inference_with_dora(self):
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
- scheduler_cls, use_dora=True
- )
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
- output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_dora_lora.shape == self.output_shape)
-
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
- output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
self.assertFalse(
- np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
- "DoRA lora should change the output",
+ np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+ "LoRA should change the output",
)
+ def test_simple_inference_with_dora(self):
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_dora_lora.shape == self.output_shape)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+ output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertFalse(
+ np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
+ "DoRA lora should change the output",
+ )
+
def test_missing_keys_warning(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1805,9 +1715,8 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
def test_unexpected_keys_warning(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1842,23 +1751,21 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
+ pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
- if self.has_two_text_encoders or self.has_three_text_encoders:
- pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
-
- # Just makes sure it works..
- _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # Just makes sure it works.
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_modify_padding_mode(self):
def set_pad_mode(network, mode="circular"):
@@ -1866,28 +1773,26 @@ class PeftLoraLoaderMixinTests:
if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode
- for scheduler_cls in self.scheduler_classes:
- components, _, _ = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _pad_mode = "circular"
- set_pad_mode(pipe.vae, _pad_mode)
- set_pad_mode(pipe.unet, _pad_mode)
+ components, _, _ = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _pad_mode = "circular"
+ set_pad_mode(pipe.vae, _pad_mode)
+ set_pad_mode(pipe.unet, _pad_mode)
- _, _, inputs = self.get_dummy_inputs()
- _ = pipe(**inputs)[0]
+ _, _, inputs = self.get_dummy_inputs()
+ _ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self):
- scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
- components, _, _ = self.get_dummy_components(scheduler_cls)
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft")
@@ -1899,7 +1804,7 @@ class PeftLoraLoaderMixinTests:
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
- self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
+ self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
# test only for text encoder
for lora_module in self.pipeline_class._lora_loadable_modules:
@@ -1925,73 +1830,69 @@ class PeftLoraLoaderMixinTests:
def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ output_no_lora = self.get_base_pipe_output()
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
+
+ lora_scale = 0.5
+ attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+ output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+ self.assertFalse(
+ np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+
+ pipe.set_adapters("default", lora_scale)
+ output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should change the output",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ "Lora + scale should match the output of `set_adapters()`.",
+ )
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
+ for module_name, module in modules_to_save.items():
+ self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
- pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
-
- lora_scale = 0.5
- attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
- output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertFalse(
- np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
-
- pipe.set_adapters("default", lora_scale)
- output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
- not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
+ not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
- np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
- "Lora + scale should match the output of `set_adapters()`.",
+ np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as attention_kwargs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
+ "Loading from saved checkpoints should give same results as set_adapters().",
)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(
- save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
- )
-
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
-
- for module_name, module in modules_to_save.items():
- self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
-
- output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
- self.assertTrue(
- not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Lora + scale should change the output",
- )
- self.assertTrue(
- np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results as attention_kwargs.",
- )
- self.assertTrue(
- np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
- "Loading from saved checkpoints should give same results as set_adapters().",
- )
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2028,7 +1929,7 @@ class PeftLoraLoaderMixinTests:
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self):
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2114,7 +2015,7 @@ class PeftLoraLoaderMixinTests:
self.assertEqual(submodule.bias.dtype, dtype_to_check)
def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
@@ -2181,7 +2082,7 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
# 1. Test forward with add_adapter
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
@@ -2211,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
@@ -2231,10 +2132,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
- scheduler_cls, lora_alpha=lora_alpha
- )
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe, _ = self.add_adapters_to_pipeline(
@@ -2280,16 +2178,10 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
- scheduler_cls, lora_alpha=lora_alpha
- )
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(output_no_lora.shape == self.output_shape)
-
pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2311,8 +2203,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
- scheduler_cls = self.scheduler_classes[0]
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
@@ -2330,51 +2221,48 @@ class PeftLoraLoaderMixinTests:
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
- for scheduler_cls in self.scheduler_classes:
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
- output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ output_no_lora = self.get_base_pipe_output()
- if "text_encoder" in self.pipeline_class._lora_loadable_modules:
- pipe.text_encoder.add_adapter(text_lora_config)
+ if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+ pipe.text_encoder.add_adapter(text_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ if self.has_two_text_encoders or self.has_three_text_encoders:
+ lora_loadable_components = self.pipeline_class._lora_loadable_modules
+ if "text_encoder_2" in lora_loadable_components:
+ pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+ check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
- denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
- denoiser.add_adapter(denoiser_lora_config)
- self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+ output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
- if self.has_two_text_encoders or self.has_three_text_encoders:
- lora_loadable_components = self.pipeline_class._lora_loadable_modules
- if "text_encoder_2" in lora_loadable_components:
- pipe.text_encoder_2.add_adapter(text_lora_config)
- self.assertTrue(
- check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
- )
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ # First, delete adapter and compare.
+ pipe.delete_adapters(pipe.get_active_adapters()[0])
+ output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
+ self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
- with tempfile.TemporaryDirectory() as tmpdirname:
- modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
- lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
- self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
- self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
-
- # First, delete adapter and compare.
- pipe.delete_adapters(pipe.get_active_adapters()[0])
- output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
- self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
-
- # Then load adapter and compare.
- pipe.load_lora_weights(tmpdirname)
- output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
- self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+ # Then load adapter and compare.
+ pipe.load_lora_weights(tmpdirname)
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
@@ -2382,7 +2270,7 @@ class PeftLoraLoaderMixinTests:
onload_device = torch_device
offload_device = torch.device("cpu")
- components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -2399,7 +2287,7 @@ class PeftLoraLoaderMixinTests:
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
- components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
@@ -2451,7 +2339,7 @@ class PeftLoraLoaderMixinTests:
@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -2470,7 +2358,7 @@ class PeftLoraLoaderMixinTests:
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
# reinitialize the pipeline to mimic the inference workflow.
- components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_lora_weights(tmpdirname)
diff --git a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
index 7eb830cd50..2476ab92f7 100644
--- a/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_asymmetric_autoencoder_kl.py
@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
diff --git a/tests/models/autoencoders/test_models_autoencoder_cosmos.py b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
index ceccc2364e..5898ae776a 100644
--- a/tests/models/autoencoders/test_models_autoencoder_cosmos.py
+++ b/tests/models/autoencoders/test_models_autoencoder_cosmos.py
@@ -17,13 +17,14 @@ import unittest
from diffusers import AutoencoderKLCosmos
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
-
- @unittest.skip("Unsupported test.")
- def test_forward_with_norm_groups(self):
- pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_dc.py b/tests/models/autoencoders/test_models_autoencoder_dc.py
index 56f172f1c8..d34001e7b9 100644
--- a/tests/models/autoencoders/test_models_autoencoder_dc.py
+++ b/tests/models/autoencoders/test_models_autoencoder_dc.py
@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderDC
-from ...testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
@@ -82,6 +79,6 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
- def test_forward_with_norm_groups(self):
- pass
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
index 6f91f8bfa9..9813772a7c 100644
--- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -20,18 +20,15 @@ import torch
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
-from ...testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py
index 662a3f1b80..5f11c6cb0a 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl.py
@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
index 739daf2a49..b6d59489d9 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_cogvideox.py
@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te
inputs_dict = self.dummy_input
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",
diff --git a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
index 6cb427bff8..93f40f44a9 100644
--- a/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
+++ b/tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py
@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
- @unittest.skip("Test unsupported.")
- def test_forward_with_norm_groups(self):
- pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
index 21ab3896c8..527be1b4ec 100644
--- a/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
+++ b/tests/models/autoencoders/test_models_autoencoder_ltx_video.py
@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.
pass
-class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
-
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
diff --git a/tests/models/autoencoders/test_models_autoencoder_magvit.py b/tests/models/autoencoders/test_models_autoencoder_magvit.py
index 58cbfc05bd..f7304df140 100644
--- a/tests/models/autoencoders/test_models_autoencoder_magvit.py
+++ b/tests/models/autoencoders/test_models_autoencoder_magvit.py
@@ -18,13 +18,14 @@ import unittest
from diffusers import AutoencoderKLMagvit
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
+
+ @unittest.skip(
+ "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
+ )
+ def test_enable_disable_slicing(self):
+ pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_mochi.py b/tests/models/autoencoders/test_models_autoencoder_mochi.py
index b8c5aaaa1e..ab8d429a67 100755
--- a/tests/models/autoencoders/test_models_autoencoder_mochi.py
+++ b/tests/models/autoencoders/test_models_autoencoder_mochi.py
@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderKLMochi
-from ...testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi
main_input_name = "sample"
base_precision = 1e-2
@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
- @unittest.skip("Unsupported test.")
- def test_forward_with_norm_groups(self):
- """
- tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
- TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
- """
- pass
-
@unittest.skip("Unsupported test.")
def test_model_parallelism(self):
"""
diff --git a/tests/models/autoencoders/test_models_autoencoder_oobleck.py b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
index eb7bd50f4a..d10e8ba33a 100644
--- a/tests/models/autoencoders/test_models_autoencoder_oobleck.py
+++ b/tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -30,13 +30,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
- @unittest.skip("Test unsupported.")
- def test_forward_with_norm_groups(self):
- pass
-
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return
diff --git a/tests/models/autoencoders/test_models_autoencoder_tiny.py b/tests/models/autoencoders/test_models_autoencoder_tiny.py
index 4d1dc69cfa..68232aa12f 100644
--- a/tests/models/autoencoders/test_models_autoencoder_tiny.py
+++ b/tests/models/autoencoders/test_models_autoencoder_tiny.py
@@ -31,13 +31,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny
main_input_name = "sample"
base_precision = 1e-2
@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
def test_enable_disable_tiling(self):
pass
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict)[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict)[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict)[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
@unittest.skip("Test not supported.")
def test_outputs_equivalence(self):
pass
diff --git a/tests/models/autoencoders/test_models_autoencoder_wan.py b/tests/models/autoencoders/test_models_autoencoder_wan.py
index cc9c888681..051098dc7a 100644
--- a/tests/models/autoencoders/test_models_autoencoder_wan.py
+++ b/tests/models/autoencoders/test_models_autoencoder_wan.py
@@ -15,18 +15,17 @@
import unittest
-import torch
-
from diffusers import AutoencoderKLWan
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling(96, 96, 64, 64)
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.05,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
diff --git a/tests/models/autoencoders/test_models_consistency_decoder_vae.py b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
index 7e44edba36..ef04d151ec 100644
--- a/tests/models/autoencoders/test_models_consistency_decoder_vae.py
+++ b/tests/models/autoencoders/test_models_consistency_decoder_vae.py
@@ -31,12 +31,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
- def test_enable_disable_tiling(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
- _ = inputs_dict.pop("generator")
-
- torch.manual_seed(0)
- output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_tiling()
- output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
- 0.5,
- "VAE tiling should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_tiling()
- output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_tiling.detach().cpu().numpy().all(),
- output_without_tiling_2.detach().cpu().numpy().all(),
- "Without tiling outputs should match with the outputs when tiling is manually disabled.",
- )
-
- def test_enable_disable_slicing(self):
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
- torch.manual_seed(0)
- model = self.model_class(**init_dict).to(torch_device)
-
- inputs_dict.update({"return_dict": False})
- _ = inputs_dict.pop("generator")
-
- torch.manual_seed(0)
- output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- torch.manual_seed(0)
- model.enable_slicing()
- output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertLess(
- (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
- 0.5,
- "VAE slicing should not affect the inference results",
- )
-
- torch.manual_seed(0)
- model.disable_slicing()
- output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
- self.assertEqual(
- output_without_slicing.detach().cpu().numpy().all(),
- output_without_slicing_2.detach().cpu().numpy().all(),
- "Without slicing outputs should match with the outputs when slicing is manually disabled.",
- )
-
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
diff --git a/tests/models/autoencoders/test_models_vq.py b/tests/models/autoencoders/test_models_vq.py
index 1c636b0817..b88d24d1f2 100644
--- a/tests/models/autoencoders/test_models_vq.py
+++ b/tests/models/autoencoders/test_models_vq.py
@@ -19,19 +19,15 @@ import torch
from diffusers import VQModel
-from ...testing_utils import (
- backend_manual_seed,
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
-class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel
main_input_name = "sample"
diff --git a/tests/models/autoencoders/testing_utils.py b/tests/models/autoencoders/testing_utils.py
new file mode 100644
index 0000000000..8ae362ac2e
--- /dev/null
+++ b/tests/models/autoencoders/testing_utils.py
@@ -0,0 +1,147 @@
+import inspect
+
+import numpy as np
+import pytest
+import torch
+
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.utils.torch_utils import torch_device
+
+
+class AutoencoderTesterMixin:
+ """
+ Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
+ usually don't do slicing and tiling.
+ """
+
+ @staticmethod
+ def _accepts_generator(model):
+ model_sig = inspect.signature(model.forward)
+ accepts_generator = "generator" in model_sig.parameters
+ return accepts_generator
+
+ @staticmethod
+ def _accepts_norm_num_groups(model_class):
+ model_sig = inspect.signature(model_class.__init__)
+ accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+ return accepts_norm_groups
+
+ def test_forward_with_norm_groups(self):
+ if not self._accepts_norm_num_groups(self.model_class):
+ pytest.skip(f"Test not supported for {self.model_class.__name__}")
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ init_dict["norm_num_groups"] = 16
+ init_dict["block_out_channels"] = (16, 32)
+
+ model = self.model_class(**init_dict)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ output = model(**inputs_dict)
+
+ if isinstance(output, dict):
+ output = output.to_tuple()[0]
+
+ self.assertIsNotNone(output)
+ expected_shape = inputs_dict["sample"].shape
+ self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+ def test_enable_disable_tiling(self):
+ if not hasattr(self.model_class, "enable_tiling"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+
+ if not hasattr(model, "use_tiling"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator", None)
+ accepts_generator = self._accepts_generator(model)
+
+ torch.manual_seed(0)
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_tiling = model(**inputs_dict)[0]
+ # Mochi-1
+ if isinstance(output_without_tiling, DecoderOutput):
+ output_without_tiling = output_without_tiling.sample
+
+ torch.manual_seed(0)
+ model.enable_tiling()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_with_tiling = model(**inputs_dict)[0]
+ if isinstance(output_with_tiling, DecoderOutput):
+ output_with_tiling = output_with_tiling.sample
+
+ assert (
+ output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
+ ).max() < 0.5, "VAE tiling should not affect the inference results"
+
+ torch.manual_seed(0)
+ model.disable_tiling()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_tiling_2 = model(**inputs_dict)[0]
+ if isinstance(output_without_tiling_2, DecoderOutput):
+ output_without_tiling_2 = output_without_tiling_2.sample
+
+ assert np.allclose(
+ output_without_tiling.detach().cpu().numpy().all(),
+ output_without_tiling_2.detach().cpu().numpy().all(),
+ ), "Without tiling outputs should match with the outputs when tiling is manually disabled."
+
+ def test_enable_disable_slicing(self):
+ if not hasattr(self.model_class, "enable_slicing"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ torch.manual_seed(0)
+ model = self.model_class(**init_dict).to(torch_device)
+ if not hasattr(model, "use_slicing"):
+ pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
+
+ inputs_dict.update({"return_dict": False})
+ _ = inputs_dict.pop("generator", None)
+ accepts_generator = self._accepts_generator(model)
+
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+
+ torch.manual_seed(0)
+ output_without_slicing = model(**inputs_dict)[0]
+ # Mochi-1
+ if isinstance(output_without_slicing, DecoderOutput):
+ output_without_slicing = output_without_slicing.sample
+
+ torch.manual_seed(0)
+ model.enable_slicing()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_with_slicing = model(**inputs_dict)[0]
+ if isinstance(output_with_slicing, DecoderOutput):
+ output_with_slicing = output_with_slicing.sample
+
+ assert (
+ output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
+ ).max() < 0.5, "VAE slicing should not affect the inference results"
+
+ torch.manual_seed(0)
+ model.disable_slicing()
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ output_without_slicing_2 = model(**inputs_dict)[0]
+ if isinstance(output_without_slicing_2, DecoderOutput):
+ output_without_slicing_2 = output_without_slicing_2.sample
+
+ assert np.allclose(
+ output_without_slicing.detach().cpu().numpy().all(),
+ output_without_slicing_2.detach().cpu().numpy().all(),
+ ), "Without slicing outputs should match with the outputs when slicing is manually disabled."
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5e7be62342..6f4c3d544b 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -25,7 +25,6 @@ import traceback
import unittest
import unittest.mock as mock
import uuid
-import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
@@ -37,9 +36,8 @@ import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
from huggingface_hub import ModelCard, delete_repo, snapshot_download, try_to_load_from_cache
-from huggingface_hub.utils import is_jinja_available
+from huggingface_hub.utils import HfHubHTTPError, is_jinja_available
from parameterized import parameterized
-from requests.exceptions import HTTPError
from diffusers.models import FluxTransformer2DModel, SD3Transformer2DModel, UNet2DConditionModel
from diffusers.models.attention_processor import (
@@ -244,8 +242,8 @@ class ModelUtilsTest(unittest.TestCase):
else:
_ = load_model(repo_id)
- warning_message = str(warning.warnings[0].message)
- self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_message)
+ warning_messages = " ".join(str(w.message) for w in warning.warnings)
+ self.assertIn("This serialization format is now deprecated to standardize the serialization", warning_messages)
# Local tests are already covered down below.
@parameterized.expand(
@@ -272,7 +270,7 @@ class ModelUtilsTest(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -296,14 +294,16 @@ class ModelUtilsTest(unittest.TestCase):
error_response = mock.Mock(
status_code=500,
headers={},
- raise_for_status=mock.Mock(side_effect=HTTPError),
+ raise_for_status=mock.Mock(side_effect=HfHubHTTPError("Server down", response=mock.Mock())),
json=mock.Mock(return_value={}),
)
+ client_mock = mock.Mock()
+ client_mock.get.return_value = error_response
with tempfile.TemporaryDirectory() as tmpdir:
model = FluxTransformer2DModel.from_pretrained(repo_id, subfolder="transformer", cache_dir=tmpdir)
- with mock.patch("requests.Session.get", return_value=error_response):
+ with mock.patch("huggingface_hub.hf_api.get_session", return_value=client_mock):
# Should fail with local_files_only=False (network required)
# We would make a network call with model_info
with self.assertRaises(OSError):
@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):
class UNetTesterMixin:
+ @staticmethod
+ def _accepts_norm_num_groups(model_class):
+ model_sig = inspect.signature(model_class.__init__)
+ accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+ return accepts_norm_groups
+
def test_forward_with_norm_groups(self):
+ if not self._accepts_norm_num_groups(self.model_class):
+ pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
@@ -1794,11 +1802,6 @@ class ModelTesterMixin:
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
- if self.model_class.__name__ == "QwenImageTransformer2DModel":
- pytest.skip(
- "QwenImageTransformer2DModel doesn't support group offloading with disk. Needs to be investigated."
- )
-
def _has_generator_arg(model):
sig = inspect.signature(model.forward)
params = sig.parameters
@@ -2377,14 +2380,15 @@ class LoraHotSwappingForModelTesterMixin:
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self):
# check possibility to ignore the error/warning
+ from diffusers.loaders.peft import logger
+
lora_config = self.get_lora_config(8, 8, target_modules=["to_q"])
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
- with warnings.catch_warnings(record=True) as w:
- warnings.simplefilter("always") # Capture all warnings
- model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
- self.assertEqual(len(w), 0, f"Expected no warnings, but got: {[str(warn.message) for warn in w]}")
+ # note: assertNoLogs requires Python 3.10+
+ with self.assertNoLogs(logger, level="WARNING"):
+ model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
diff --git a/tests/models/transformers/test_models_transformer_bria_fibo.py b/tests/models/transformers/test_models_transformer_bria_fibo.py
new file mode 100644
index 0000000000..f859f4608b
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_bria_fibo.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import BriaFiboTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = BriaFiboTransformer2DModel
+ main_input_name = "hidden_states"
+ # We override the items here because the transformer under consideration is small.
+ model_split_percents = [0.8, 0.7, 0.7]
+
+ # Skip setting testing with default: AttnProcessor
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_latent_channels = 48
+ num_image_channels = 3
+ height = width = 16
+ sequence_length = 32
+ embedding_dim = 64
+
+ hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+ image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "timestep": timestep,
+ "text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 16)
+
+ @property
+ def output_shape(self):
+ return (256, 48)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": 1,
+ "in_channels": 48,
+ "num_layers": 1,
+ "num_single_layers": 1,
+ "attention_head_dim": 8,
+ "num_attention_heads": 2,
+ "joint_attention_dim": 64,
+ "text_encoder_dim": 32,
+ "pooled_projection_dim": None,
+ "axes_dims_rope": [0, 4, 4],
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"BriaFiboTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/transformers/test_models_transformer_prx.py b/tests/models/transformers/test_models_transformer_prx.py
new file mode 100644
index 0000000000..1387625d5e
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_prx.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = PRXTransformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 16, 16)
+
+ def prepare_dummy_input(self, height=16, width=16):
+ batch_size = 1
+ num_latent_channels = 16
+ sequence_length = 16
+ embedding_dim = 1792
+
+ hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+ return {
+ "hidden_states": hidden_states,
+ "timestep": timestep,
+ "encoder_hidden_states": encoder_hidden_states,
+ }
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "patch_size": 2,
+ "context_in_dim": 1792,
+ "hidden_size": 1792,
+ "mlp_ratio": 3.5,
+ "num_heads": 28,
+ "depth": 4, # Smaller depth for testing
+ "axes_dim": [32, 32],
+ "theta": 10_000,
+ }
+ inputs_dict = self.prepare_dummy_input()
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"PRXTransformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/models/transformers/test_models_transformer_sana_video.py b/tests/models/transformers/test_models_transformer_sana_video.py
new file mode 100644
index 0000000000..ff564ed891
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_sana_video.py
@@ -0,0 +1,97 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 16
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (16, 2, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (16, 2, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "in_channels": 16,
+ "out_channels": 16,
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "num_layers": 2,
+ "num_cross_attention_heads": 2,
+ "cross_attention_head_dim": 12,
+ "cross_attention_dim": 24,
+ "caption_channels": 16,
+ "mlp_ratio": 2.5,
+ "dropout": 0.0,
+ "attention_bias": False,
+ "sample_size": 8,
+ "patch_size": (1, 2, 2),
+ "norm_elementwise_affine": False,
+ "norm_eps": 1e-6,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SanaVideoTransformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class SanaVideoTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+ model_class = SanaVideoTransformer3DModel
+
+ def prepare_init_args_and_inputs_for_common(self):
+ return SanaVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
diff --git a/tests/modular_pipelines/flux/__init__.py b/tests/modular_pipelines/flux/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/modular_pipelines/flux/test_modular_pipeline_flux.py b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py
new file mode 100644
index 0000000000..a29fd43614
--- /dev/null
+++ b/tests/modular_pipelines/flux/test_modular_pipeline_flux.py
@@ -0,0 +1,172 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import tempfile
+
+import numpy as np
+import PIL
+import torch
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.modular_pipelines import (
+ FluxAutoBlocks,
+ FluxKontextAutoBlocks,
+ FluxKontextModularPipeline,
+ FluxModularPipeline,
+ ModularPipeline,
+)
+
+from ...testing_utils import floats_tensor, torch_device
+from ..test_modular_pipelines_common import ModularPipelineTesterMixin
+
+
+class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxModularPipeline
+ pipeline_blocks_class = FluxAutoBlocks
+ repo = "hf-internal-testing/tiny-flux-modular"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ return inputs
+
+
+class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxModularPipeline
+ pipeline_blocks_class = FluxAutoBlocks
+ repo = "hf-internal-testing/tiny-flux-modular"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+ batch_params = frozenset(["prompt", "image"])
+
+ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+ pipeline = super().get_pipeline(components_manager, torch_dtype)
+
+ # Override `vae_scale_factor` here as currently, `image_processor` is initialized with
+ # fixed constants instead of
+ # https://github.com/huggingface/diffusers/blob/d54622c2679d700b425ad61abce9b80fc36212c0/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L230C9-L232C10
+ pipeline.image_processor = VaeImageProcessor(vae_scale_factor=2)
+ return pipeline
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB")
+
+ inputs["image"] = init_image
+ inputs["strength"] = 0.5
+
+ return inputs
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+ pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+
+
+class TestFluxKontextModularPipelineFast(ModularPipelineTesterMixin):
+ pipeline_class = FluxKontextModularPipeline
+ pipeline_blocks_class = FluxKontextAutoBlocks
+ repo = "hf-internal-testing/tiny-flux-kontext-pipe"
+
+ params = frozenset(["prompt", "height", "width", "guidance_scale", "image"])
+ batch_params = frozenset(["prompt", "image"])
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_sequence_length": 48,
+ "output_type": "pt",
+ }
+ image = PIL.Image.new("RGB", (32, 32), 0)
+
+ inputs["image"] = image
+ inputs["max_area"] = inputs["height"] * inputs["width"]
+ inputs["_auto_resize"] = False
+
+ return inputs
+
+ def test_save_from_pretrained(self):
+ pipes = []
+ base_pipe = self.get_pipeline().to(torch_device)
+ pipes.append(base_pipe)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ base_pipe.save_pretrained(tmpdirname)
+
+ pipe = ModularPipeline.from_pretrained(tmpdirname).to(torch_device)
+ pipe.load_components(torch_dtype=torch.float32)
+ pipe.to(torch_device)
+ pipe.image_processor = VaeImageProcessor(vae_scale_factor=2)
+
+ pipes.append(pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs()
+ image = pipe(**inputs, output="images")
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
diff --git a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
index d05f818135..ea54b2bdff 100644
--- a/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
+++ b/tests/modular_pipelines/stable_diffusion_xl/test_modular_pipeline_stable_diffusion_xl.py
@@ -14,93 +14,43 @@
# limitations under the License.
import random
-import unittest
from typing import Any, Dict
import numpy as np
import torch
from PIL import Image
-from diffusers import (
- ClassifierFreeGuidance,
- StableDiffusionXLAutoBlocks,
- StableDiffusionXLModularPipeline,
-)
+from diffusers import ClassifierFreeGuidance, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
from diffusers.loaders import ModularIPAdapterMixin
-from ...models.unets.test_models_unet_2d_condition import (
- create_ip_adapter_state_dict,
-)
-from ...testing_utils import (
- enable_full_determinism,
- floats_tensor,
- torch_device,
-)
-from ..test_modular_pipelines_common import (
- ModularPipelineTesterMixin,
-)
+from ...models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modular_pipelines_common import ModularPipelineTesterMixin
enable_full_determinism()
-class SDXLModularTests:
+class SDXLModularTesterMixin:
"""
This mixin defines method to create pipeline, base input and base test across all SDXL modular tests.
"""
- pipeline_class = StableDiffusionXLModularPipeline
- pipeline_blocks_class = StableDiffusionXLAutoBlocks
- repo = "hf-internal-testing/tiny-sdxl-modular"
- params = frozenset(
- [
- "prompt",
- "height",
- "width",
- "negative_prompt",
- "cross_attention_kwargs",
- "image",
- "mask_image",
- ]
- )
- batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
-
- def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
- pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
- pipeline.load_components(torch_dtype=torch_dtype)
- return pipeline
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
def _test_stable_diffusion_xl_euler(self, expected_image_shape, expected_slice, expected_max_diff=1e-2):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
sd_pipe = self.get_pipeline()
- sd_pipe = sd_pipe.to(device)
+ sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(device)
+ inputs = self.get_dummy_inputs()
image = sd_pipe(**inputs, output="images")
image_slice = image[0, -3:, -3:, -1]
assert image.shape == expected_image_shape
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < expected_max_diff, (
- "Image Slice does not match expected slice"
- )
+ max_diff = torch.abs(image_slice.flatten() - expected_slice).max()
+ assert max_diff < expected_max_diff, f"Image slice does not match expected slice. Max Difference: {max_diff}"
-class SDXLModularIPAdapterTests:
+class SDXLModularIPAdapterTesterMixin:
"""
This mixin is designed to test IP Adapter.
"""
@@ -139,7 +89,7 @@ class SDXLModularIPAdapterTests:
if "image" in parameters and "strength" in parameters:
inputs["num_inference_steps"] = 4
- inputs["output_type"] = "np"
+ inputs["output_type"] = "pt"
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
@@ -164,7 +114,7 @@ class SDXLModularIPAdapterTests:
cross_attention_dim = pipe.unet.config.get("cross_attention_dim")
# forward pass without ip adapter
- inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs, output="images")
else:
@@ -175,7 +125,7 @@ class SDXLModularIPAdapterTests:
pipe.unet._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
- inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(0.0)
@@ -184,7 +134,7 @@ class SDXLModularIPAdapterTests:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
- inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
pipe.set_ip_adapter_scale(42.0)
@@ -192,8 +142,8 @@ class SDXLModularIPAdapterTests:
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
- max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
- max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+ max_diff_without_adapter_scale = torch.abs(output_without_adapter_scale - output_without_adapter).max()
+ max_diff_with_adapter_scale = torch.abs(output_with_adapter_scale - output_without_adapter).max()
assert max_diff_without_adapter_scale < expected_max_diff, (
"Output without ip-adapter must be same as normal inference"
@@ -206,7 +156,7 @@ class SDXLModularIPAdapterTests:
pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
# forward pass with multi ip adapter, but scale=0 which should have no effect
- inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([0.0, 0.0])
@@ -215,7 +165,7 @@ class SDXLModularIPAdapterTests:
output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with multi ip adapter, but with scale of adapter weights
- inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs())
inputs["ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
inputs["negative_ip_adapter_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
pipe.set_ip_adapter_scale([42.0, 42.0])
@@ -223,10 +173,10 @@ class SDXLModularIPAdapterTests:
if expected_pipe_slice is not None:
output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
- max_diff_without_multi_adapter_scale = np.abs(
+ max_diff_without_multi_adapter_scale = torch.abs(
output_without_multi_adapter_scale - output_without_adapter
).max()
- max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+ max_diff_with_multi_adapter_scale = torch.abs(output_with_multi_adapter_scale - output_without_adapter).max()
assert max_diff_without_multi_adapter_scale < expected_max_diff, (
"Output without multi-ip-adapter must be same as normal inference"
)
@@ -235,7 +185,7 @@ class SDXLModularIPAdapterTests:
)
-class SDXLModularControlNetTests:
+class SDXLModularControlNetTesterMixin:
"""
This mixin is designed to test ControlNet.
"""
@@ -274,24 +224,26 @@ class SDXLModularControlNetTests:
pipe.set_progress_bar_config(disable=None)
# forward pass without controlnet
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
output_without_controlnet = pipe(**inputs, output="images")
output_without_controlnet = output_without_controlnet[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but scale=0 which should have no effect
- inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
inputs["controlnet_conditioning_scale"] = 0.0
output_without_controlnet_scale = pipe(**inputs, output="images")
output_without_controlnet_scale = output_without_controlnet_scale[0, -3:, -3:, -1].flatten()
# forward pass with single controlnet, but with scale of adapter weights
- inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
inputs["controlnet_conditioning_scale"] = 42.0
output_with_controlnet_scale = pipe(**inputs, output="images")
output_with_controlnet_scale = output_with_controlnet_scale[0, -3:, -3:, -1].flatten()
- max_diff_without_controlnet_scale = np.abs(output_without_controlnet_scale - output_without_controlnet).max()
- max_diff_with_controlnet_scale = np.abs(output_with_controlnet_scale - output_without_controlnet).max()
+ max_diff_without_controlnet_scale = torch.abs(
+ output_without_controlnet_scale - output_without_controlnet
+ ).max()
+ max_diff_with_controlnet_scale = torch.abs(output_with_controlnet_scale - output_without_controlnet).max()
assert max_diff_without_controlnet_scale < expected_max_diff, (
"Output without controlnet must be same as normal inference"
@@ -307,21 +259,21 @@ class SDXLModularControlNetTests:
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
- inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
- inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs(torch_device))
+ inputs = self._modify_inputs_for_controlnet_test(self.get_dummy_inputs())
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
- max_diff = np.abs(out_cfg - out_no_cfg).max()
+ max_diff = torch.abs(out_cfg - out_no_cfg).max()
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-class SDXLModularGuiderTests:
+class SDXLModularGuiderTesterMixin:
def test_guider_cfg(self):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)
@@ -331,13 +283,13 @@ class SDXLModularGuiderTests:
guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output="images")
# forward pass with CFG applied
guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output="images")
assert out_cfg.shape == out_no_cfg.shape
@@ -345,30 +297,57 @@ class SDXLModularGuiderTests:
assert max_diff > 1e-2, "Output with CFG must be different from normal inference"
-class SDXLModularPipelineFastTests(
- SDXLModularTests,
- SDXLModularIPAdapterTests,
- SDXLModularControlNetTests,
- SDXLModularGuiderTests,
+class TestSDXLModularPipelineFast(
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ SDXLModularGuiderTesterMixin,
ModularPipelineTesterMixin,
- unittest.TestCase,
):
"""Test cases for Stable Diffusion XL modular pipeline fast tests."""
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ repo = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "output_type": "pt",
+ }
+ return inputs
+
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
- expected_image_shape=(1, 64, 64, 3),
- expected_slice=[
- 0.5966781,
- 0.62939394,
- 0.48465094,
- 0.51573336,
- 0.57593524,
- 0.47035995,
- 0.53410417,
- 0.51436996,
- 0.47313565,
- ],
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor(
+ [
+ 0.5966781,
+ 0.62939394,
+ 0.48465094,
+ 0.51573336,
+ 0.57593524,
+ 0.47035995,
+ 0.53410417,
+ 0.51436996,
+ 0.47313565,
+ ],
+ device=torch_device,
+ ),
expected_max_diff=1e-2,
)
@@ -376,39 +355,65 @@ class SDXLModularPipelineFastTests(
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-class SDXLImg2ImgModularPipelineFastTests(
- SDXLModularTests,
- SDXLModularIPAdapterTests,
- SDXLModularControlNetTests,
- SDXLModularGuiderTests,
+class TestSDXLImg2ImgModularPipelineFast(
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ SDXLModularGuiderTesterMixin,
ModularPipelineTesterMixin,
- unittest.TestCase,
):
"""Test cases for Stable Diffusion XL image-to-image modular pipeline fast tests."""
- def get_dummy_inputs(self, device, seed=0):
- inputs = super().get_dummy_inputs(device, seed)
- image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
- image = image / 2 + 0.5
- inputs["image"] = image
- inputs["strength"] = 0.8
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ repo = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ "image",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt", "image"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
+ def get_dummy_inputs(self, seed=0):
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "output_type": "pt",
+ }
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+ image = image.cpu().permute(0, 2, 3, 1)[0]
+ init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
+ inputs["image"] = init_image
+ inputs["strength"] = 0.5
return inputs
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
- expected_image_shape=(1, 64, 64, 3),
- expected_slice=[
- 0.56943184,
- 0.4702148,
- 0.48048905,
- 0.6235963,
- 0.551138,
- 0.49629188,
- 0.60031277,
- 0.5688907,
- 0.43996853,
- ],
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor(
+ [
+ 0.56943184,
+ 0.4702148,
+ 0.48048905,
+ 0.6235963,
+ 0.551138,
+ 0.49629188,
+ 0.60031277,
+ 0.5688907,
+ 0.43996853,
+ ],
+ device=torch_device,
+ ),
expected_max_diff=1e-2,
)
@@ -417,20 +422,43 @@ class SDXLImg2ImgModularPipelineFastTests(
class SDXLInpaintingModularPipelineFastTests(
- SDXLModularTests,
- SDXLModularIPAdapterTests,
- SDXLModularControlNetTests,
- SDXLModularGuiderTests,
+ SDXLModularTesterMixin,
+ SDXLModularIPAdapterTesterMixin,
+ SDXLModularControlNetTesterMixin,
+ SDXLModularGuiderTesterMixin,
ModularPipelineTesterMixin,
- unittest.TestCase,
):
"""Test cases for Stable Diffusion XL inpainting modular pipeline fast tests."""
+ pipeline_class = StableDiffusionXLModularPipeline
+ pipeline_blocks_class = StableDiffusionXLAutoBlocks
+ repo = "hf-internal-testing/tiny-sdxl-modular"
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "negative_prompt",
+ "cross_attention_kwargs",
+ "image",
+ "mask_image",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])
+ expected_image_output_shape = (1, 3, 64, 64)
+
def get_dummy_inputs(self, device, seed=0):
- inputs = super().get_dummy_inputs(device, seed)
+ generator = self.get_generator(seed)
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 4,
+ "output_type": "pt",
+ }
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
+
# create mask
image[8:, 8:, :] = 255
mask_image = Image.fromarray(np.uint8(image)).convert("L").resize((64, 64))
@@ -443,18 +471,21 @@ class SDXLInpaintingModularPipelineFastTests(
def test_stable_diffusion_xl_euler(self):
self._test_stable_diffusion_xl_euler(
- expected_image_shape=(1, 64, 64, 3),
- expected_slice=[
- 0.40872607,
- 0.38842705,
- 0.34893104,
- 0.47837183,
- 0.43792963,
- 0.5332134,
- 0.3716843,
- 0.47274873,
- 0.45000193,
- ],
+ expected_image_shape=self.expected_image_output_shape,
+ expected_slice=torch.tensor(
+ [
+ 0.40872607,
+ 0.38842705,
+ 0.34893104,
+ 0.47837183,
+ 0.43792963,
+ 0.5332134,
+ 0.3716843,
+ 0.47274873,
+ 0.45000193,
+ ],
+ device=torch_device,
+ ),
expected_max_diff=1e-2,
)
diff --git a/tests/modular_pipelines/test_modular_pipelines_common.py b/tests/modular_pipelines/test_modular_pipelines_common.py
index d309fcf353..1325e5c1de 100644
--- a/tests/modular_pipelines/test_modular_pipelines_common.py
+++ b/tests/modular_pipelines/test_modular_pipelines_common.py
@@ -1,9 +1,7 @@
import gc
import tempfile
-import unittest
from typing import Callable, Union
-import numpy as np
import torch
import diffusers
@@ -19,17 +17,9 @@ from ..testing_utils import (
)
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
@require_torch
class ModularPipelineTesterMixin:
"""
- This mixin is designed to be used with unittest.TestCase classes.
It provides a set of common tests for each modular pipeline,
including:
- test_pipeline_call_signature: check if the pipeline's __call__ method has all required parameters
@@ -57,9 +47,8 @@ class ModularPipelineTesterMixin:
]
)
- def get_generator(self, seed):
- device = torch_device if torch_device != "mps" else "cpu"
- generator = torch.Generator(device).manual_seed(seed)
+ def get_generator(self, seed=0):
+ generator = torch.Generator("cpu").manual_seed(seed)
return generator
@property
@@ -82,13 +71,7 @@ class ModularPipelineTesterMixin:
"See existing pipeline tests for reference."
)
- def get_pipeline(self):
- raise NotImplementedError(
- "You need to implement `get_pipeline(self)` in the child test class. "
- "See existing pipeline tests for reference."
- )
-
- def get_dummy_inputs(self, device, seed=0):
+ def get_dummy_inputs(self, seed=0):
raise NotImplementedError(
"You need to implement `get_dummy_inputs(self, device, seed)` in the child test class. "
"See existing pipeline tests for reference."
@@ -123,20 +106,23 @@ class ModularPipelineTesterMixin:
"See existing pipeline tests for reference."
)
- def setUp(self):
+ def setup_method(self):
# clean up the VRAM before each test
- super().setUp()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
+ def teardown_method(self):
# clean up the VRAM after each test in case of CUDA runtime errors
- super().tearDown()
torch.compiler.reset()
gc.collect()
backend_empty_cache(torch_device)
+ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32):
+ pipeline = self.pipeline_blocks_class().init_pipeline(self.repo, components_manager=components_manager)
+ pipeline.load_components(torch_dtype=torch_dtype)
+ return pipeline
+
def test_pipeline_call_signature(self):
pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names
@@ -156,7 +142,7 @@ class ModularPipelineTesterMixin:
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
inputs["generator"] = self.get_generator(0)
logger = logging.get_logger(pipe.__module__)
@@ -196,7 +182,7 @@ class ModularPipelineTesterMixin:
pipe = self.get_pipeline()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
# Reset generator in case it is has been used in self.get_dummy_inputs
inputs["generator"] = self.get_generator(0)
@@ -226,10 +212,9 @@ class ModularPipelineTesterMixin:
assert output_batch.shape[0] == batch_size
- max_diff = np.abs(to_np(output_batch[0]) - to_np(output[0])).max()
+ max_diff = torch.abs(output_batch[0] - output[0]).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
@@ -240,13 +225,13 @@ class ModularPipelineTesterMixin:
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
output = pipe(**inputs, output="images")
- fp16_inputs = self.get_dummy_inputs(torch_device)
+ fp16_inputs = self.get_dummy_inputs()
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
@@ -283,8 +268,8 @@ class ModularPipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
- output = pipe(**self.get_dummy_inputs("cpu"), output="images")
- assert np.isnan(to_np(output)).sum() == 0, "CPU Inference returns NaN"
+ output = pipe(**self.get_dummy_inputs(), output="images")
+ assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"
@require_accelerator
def test_inference_is_not_nan(self):
@@ -292,8 +277,8 @@ class ModularPipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to(torch_device)
- output = pipe(**self.get_dummy_inputs(torch_device), output="images")
- assert np.isnan(to_np(output)).sum() == 0, "Accelerator Inference returns NaN"
+ output = pipe(**self.get_dummy_inputs(), output="images")
+ assert torch.isnan(output).sum() == 0, "Accelerator Inference returns NaN"
def test_num_images_per_prompt(self):
pipe = self.get_pipeline()
@@ -309,7 +294,7 @@ class ModularPipelineTesterMixin:
for batch_size in batch_sizes:
for num_images_per_prompt in num_images_per_prompts:
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
for key in inputs.keys():
if key in self.batch_params:
@@ -329,12 +314,12 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in [base_pipe, offload_pipe]:
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_save_from_pretrained(self):
pipes = []
@@ -351,9 +336,9 @@ class ModularPipelineTesterMixin:
image_slices = []
for pipe in pipes:
- inputs = self.get_dummy_inputs(torch_device)
+ inputs = self.get_dummy_inputs()
image = pipe(**inputs, output="images")
image_slices.append(image[0, -3:, -3:, -1].flatten())
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
diff --git a/tests/others/test_attention_backends.py b/tests/others/test_attention_backends.py
new file mode 100644
index 0000000000..2e5a2fc82b
--- /dev/null
+++ b/tests/others/test_attention_backends.py
@@ -0,0 +1,157 @@
+"""
+This test suite exists for the maintainers currently. It's not run in our CI at the moment.
+
+Once attention backends become more mature, we can consider including this in our CI.
+
+To run this test suite:
+
+```bash
+export RUN_ATTENTION_BACKEND_TESTS=yes
+export DIFFUSERS_ENABLE_HUB_KERNELS=yes
+
+pytest tests/others/test_attention_backends.py
+```
+
+Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
+"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+
+Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
+with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
+aiter 0.1.5.post4.dev20+ga25e55e79.
+"""
+
+import os
+
+import pytest
+import torch
+
+
+pytestmark = pytest.mark.skipif(
+ os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
+)
+from diffusers import FluxPipeline # noqa: E402
+from diffusers.utils import is_torch_version # noqa: E402
+
+
+# fmt: off
+FORWARD_CASES = [
+ ("flash_hub", None),
+ (
+ "_flash_3_hub",
+ torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
+ ),
+ (
+ "native",
+ torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
+ ),
+ (
+ "_native_cudnn",
+ torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
+ ),
+ (
+ "aiter",
+ torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+ )
+]
+
+COMPILE_CASES = [
+ ("flash_hub", None, True),
+ (
+ "_flash_3_hub",
+ torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "native",
+ torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "_native_cudnn",
+ torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
+ True,
+ ),
+ (
+ "aiter",
+ torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
+ True,
+ )
+]
+# fmt: on
+
+INFER_KW = {
+ "prompt": "dance doggo dance",
+ "height": 256,
+ "width": 256,
+ "num_inference_steps": 2,
+ "guidance_scale": 3.5,
+ "max_sequence_length": 128,
+ "output_type": "pt",
+}
+
+
+def _backend_is_probably_supported(pipe, name: str):
+ try:
+ pipe.transformer.set_attention_backend(name)
+ return pipe, True
+ except Exception:
+ return False
+
+
+def _check_if_slices_match(output, expected_slice):
+ img = output.images.detach().cpu()
+ generated_slice = img.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
+
+
+@pytest.fixture(scope="session")
+def device():
+ if not torch.cuda.is_available():
+ pytest.skip("CUDA is required for these tests.")
+ return torch.device("cuda:0")
+
+
+@pytest.fixture(scope="session")
+def pipe(device):
+ repo_id = "black-forest-labs/FLUX.1-dev"
+ pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
+ pipe.set_progress_bar_config(disable=True)
+ return pipe
+
+
+@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+def test_forward(pipe, backend_name, expected_slice):
+ out = _backend_is_probably_supported(pipe, backend_name)
+ if isinstance(out, bool):
+ pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+ modified_pipe = out[0]
+ out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+ _check_if_slices_match(out, expected_slice)
+
+
+@pytest.mark.parametrize(
+ "backend_name,expected_slice,error_on_recompile",
+ COMPILE_CASES,
+ ids=[c[0] for c in COMPILE_CASES],
+)
+def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
+ if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
+ pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
+
+ out = _backend_is_probably_supported(pipe, backend_name)
+ if isinstance(out, bool):
+ pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
+
+ modified_pipe = out[0]
+ modified_pipe.transformer.compile(fullgraph=True)
+
+ torch.compiler.reset()
+ with (
+ torch._inductor.utils.fresh_inductor_cache(),
+ torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
+ ):
+ out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
+
+ _check_if_slices_match(out, expected_slice)
diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py
index e4bc5cc110..14ff1272a2 100644
--- a/tests/pipelines/audioldm2/test_audioldm2.py
+++ b/tests/pipelines/audioldm2/test_audioldm2.py
@@ -138,10 +138,8 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
patch_stride=2,
patch_embed_input_channels=4,
)
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config,
- audio_config=audio_branch_config,
- projection_dim=16,
+ text_encoder_config = ClapConfig(
+ text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=16
)
text_encoder = ClapModel(text_encoder_config)
tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
diff --git a/tests/pipelines/bria_fibo/__init__.py b/tests/pipelines/bria_fibo/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
new file mode 100644
index 0000000000..76b41114f8
--- /dev/null
+++ b/tests/pipelines/bria_fibo/test_pipeline_bria_fibo.py
@@ -0,0 +1,139 @@
+# Copyright 2024 Bria AI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer
+from transformers.models.smollm3.modeling_smollm3 import SmolLM3Config, SmolLM3ForCausalLM
+
+from diffusers import (
+ AutoencoderKLWan,
+ BriaFiboPipeline,
+ FlowMatchEulerDiscreteScheduler,
+)
+from diffusers.models.transformers.transformer_bria_fibo import BriaFiboTransformer2DModel
+from tests.pipelines.test_pipelines_common import PipelineTesterMixin
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+
+enable_full_determinism()
+
+
+class BriaFiboPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = BriaFiboPipeline
+ params = frozenset(["prompt", "height", "width", "guidance_scale"])
+ batch_params = frozenset(["prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = False
+ test_group_offloading = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = BriaFiboTransformer2DModel(
+ patch_size=1,
+ in_channels=16,
+ num_layers=1,
+ num_single_layers=1,
+ attention_head_dim=8,
+ num_attention_heads=2,
+ joint_attention_dim=64,
+ text_encoder_dim=32,
+ pooled_projection_dim=None,
+ axes_dims_rope=[0, 4, 4],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=160,
+ decoder_base_dim=256,
+ num_res_blocks=2,
+ out_channels=12,
+ patch_size=2,
+ scale_factor_spatial=16,
+ scale_factor_temporal=4,
+ temperal_downsample=[False, True, True],
+ z_dim=16,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ text_encoder = SmolLM3ForCausalLM(SmolLM3Config(hidden_size=32))
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ components = {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "transformer": transformer,
+ "vae": vae,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "{'text': 'A painting of a squirrel eating a burger'}",
+ "negative_prompt": "bad, ugly",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "np",
+ }
+ return inputs
+
+ @unittest.skip(reason="will not be supported due to dim-fusion")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
+
+ def test_bria_fibo_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+ assert max_diff > 1e-6
+
+ def test_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components())
+ pipe = pipe.to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (64, 64), (32, 64)]
+ for height, width in height_width_pairs:
+ expected_height = height
+ expected_width = width
+
+ inputs.update({"height": height, "width": width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
diff --git a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
index 9619843779..bf31f2abcf 100644
--- a/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
+++ b/tests/pipelines/controlnet_hunyuandit/test_controlnet_hunyuandit.py
@@ -155,7 +155,7 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix
if torch_device == "xpu":
expected_slice = np.array(
- [0.6376953, 0.84375, 0.58691406, 0.48046875, 0.43652344, 0.5517578, 0.54248047, 0.5644531, 0.48217773]
+ [0.6948242, 0.89160156, 0.59375, 0.5078125, 0.57910156, 0.6035156, 0.58447266, 0.53564453, 0.52246094]
)
else:
expected_slice = np.array(
diff --git a/tests/pipelines/easyanimate/test_easyanimate.py b/tests/pipelines/easyanimate/test_easyanimate.py
index 2dbb8639f1..5cb2a232bb 100644
--- a/tests/pipelines/easyanimate/test_easyanimate.py
+++ b/tests/pipelines/easyanimate/test_easyanimate.py
@@ -48,6 +48,7 @@ class EasyAnimatePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ test_xformers_attention = False
required_optional_params = frozenset(
[
"num_inference_steps",
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index c3e8517d64..1ddbd4ba3d 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -15,6 +15,7 @@ from diffusers import (
)
from ...testing_utils import (
+ Expectations,
backend_empty_cache,
nightly,
numpy_cosine_similarity_distance,
@@ -276,10 +277,14 @@ class FluxPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
# fmt: off
- expected_slice = np.array(
- [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
- dtype=np.float32,
+
+ expected_slices = Expectations(
+ {
+ ("cuda", None): np.array([0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203], dtype=np.float32,),
+ ("xpu", 3): np.array([0.3301, 0.3281, 0.3359, 0.3203, 0.3203, 0.3281, 0.3281, 0.3301, 0.3340, 0.3281, 0.3320, 0.3359, 0.3281, 0.3301, 0.3320, 0.3242, 0.3301, 0.3281, 0.3242, 0.3320, 0.3320, 0.3281, 0.3320, 0.3320, 0.3262, 0.3320, 0.3301, 0.3301, 0.3359, 0.3320], dtype=np.float32,),
+ }
)
+ expected_slice = expected_slices.get_expectation()
# fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
diff --git a/tests/pipelines/hidream_image/test_pipeline_hidream.py b/tests/pipelines/hidream_image/test_pipeline_hidream.py
index ec8d36e1d3..ddf39ba4c1 100644
--- a/tests/pipelines/hidream_image/test_pipeline_hidream.py
+++ b/tests/pipelines/hidream_image/test_pipeline_hidream.py
@@ -47,8 +47,8 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
required_optional_params = PipelineTesterMixin.required_optional_params
+ test_xformers_attention = False
test_layerwise_casting = True
supports_dduf = False
diff --git a/tests/pipelines/hunyuan_image_21/__init__.py b/tests/pipelines/hunyuan_image_21/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py b/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py
new file mode 100644
index 0000000000..e4b2c686b8
--- /dev/null
+++ b/tests/pipelines/hunyuan_image_21/test_hunyuanimage.py
@@ -0,0 +1,290 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import (
+ ByT5Tokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+ Qwen2Tokenizer,
+ T5Config,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AdaptiveProjectedMixGuidance,
+ AutoencoderKLHunyuanImage,
+ FlowMatchEulerDiscreteScheduler,
+ HunyuanImagePipeline,
+ HunyuanImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism
+from ..test_pipelines_common import (
+ FirstBlockCacheTesterMixin,
+ PipelineTesterMixin,
+ to_np,
+)
+
+
+enable_full_determinism()
+
+
+class HunyuanImagePipelineFastTests(
+ PipelineTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
+):
+ pipeline_class = HunyuanImagePipeline
+ params = frozenset(["prompt", "height", "width"])
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+ test_attention_slicing = False
+ supports_dduf = False
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1, guidance_embeds: bool = False):
+ torch.manual_seed(0)
+ transformer = HunyuanImageTransformer2DModel(
+ in_channels=4,
+ out_channels=4,
+ num_attention_heads=4,
+ attention_head_dim=8,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ num_refiner_layers=1,
+ patch_size=(1, 1),
+ guidance_embeds=guidance_embeds,
+ text_embed_dim=32,
+ text_embed_2_dim=32,
+ rope_axes_dim=(4, 4),
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanImage(
+ in_channels=3,
+ out_channels=3,
+ latent_channels=4,
+ block_out_channels=(32, 64, 64, 64),
+ layers_per_block=1,
+ scaling_factor=0.476986,
+ spatial_compression_ratio=8,
+ sample_size=128,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ if not guidance_embeds:
+ torch.manual_seed(0)
+ guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=2)
+ ocr_guider = AdaptiveProjectedMixGuidance(adaptive_projected_guidance_start_step=3)
+ else:
+ guider = None
+ ocr_guider = None
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [2, 2, 4],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 32,
+ "intermediate_size": 32,
+ "num_heads": 2,
+ "out_hidden_size": 32,
+ },
+ hidden_size=32,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ torch.manual_seed(0)
+ t5_config = T5Config(
+ d_model=32,
+ d_kv=4,
+ d_ff=16,
+ num_layers=2,
+ num_heads=2,
+ relative_attention_num_buckets=8,
+ relative_attention_max_distance=32,
+ vocab_size=256,
+ feed_forward_proj="gated-gelu",
+ dense_act_fn="gelu_new",
+ is_encoder_decoder=False,
+ use_cache=False,
+ tie_word_embeddings=False,
+ )
+ text_encoder_2 = T5EncoderModel(t5_config)
+ tokenizer_2 = ByT5Tokenizer()
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "guider": guider,
+ "ocr_guider": ocr_guider,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 5,
+ "height": 16,
+ "width": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.6252659, 0.51482046, 0.60799813, 0.59267783, 0.488082, 0.5857634, 0.523781, 0.58028054, 0.5674121]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_inference_guider(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ pipe.guider = pipe.guider.new(guidance_scale=1000)
+ pipe.ocr_guider = pipe.ocr_guider.new(guidance_scale=1000)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.61494756, 0.49616697, 0.60327923, 0.6115793, 0.49047345, 0.56977504, 0.53066164, 0.58880305, 0.5570612]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_inference_with_distilled_guidance(self):
+ device = "cpu"
+
+ components = self.get_dummy_components(guidance_embeds=True)
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ inputs["distilled_guidance_scale"] = 3.5
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 16, 16))
+
+ expected_slice_np = np.array(
+ [0.63667065, 0.5187377, 0.66757566, 0.6320319, 0.4913387, 0.54813194, 0.5335031, 0.5736143, 0.5461346]
+ )
+ output_slice = generated_image[0, -3:, -3:].flatten().cpu().numpy()
+
+ self.assertTrue(
+ np.abs(output_slice - expected_slice_np).max() < 1e-3,
+ f"output_slice: {output_slice}, expected_slice_np: {expected_slice_np}",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(tile_sample_min_size=96)
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @unittest.skip("TODO: Test not supported for now because needs to be adjusted to work with guiders.")
+ def test_encode_prompt_works_in_isolation(self):
+ pass
diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py
index 911c6e49ae..6207e71df8 100644
--- a/tests/pipelines/kandinsky/test_kandinsky.py
+++ b/tests/pipelines/kandinsky/test_kandinsky.py
@@ -18,11 +18,13 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
+from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -215,6 +217,11 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummy = Dummies()
return dummy.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py
index d744d10821..eba8976597 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_combined.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py
@@ -16,8 +16,10 @@
import unittest
import numpy as np
+import pytest
from diffusers import KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyInpaintCombinedPipeline
+from diffusers.utils import is_transformers_version
from ...testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
from ..test_pipelines_common import PipelineTesterMixin
@@ -73,6 +75,11 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
)
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -181,6 +188,11 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
@@ -292,6 +304,11 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
inputs.pop("negative_image_embeds")
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
index 4074c8db22..6d1b43a24f 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py
@@ -18,6 +18,7 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
@@ -31,6 +32,7 @@ from diffusers import (
VQModel,
)
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
+from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -237,6 +239,11 @@ class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_img2img(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
index b789a63cdd..e2f4aa2a4f 100644
--- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
+++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py
@@ -18,12 +18,14 @@ import random
import unittest
import numpy as np
+import pytest
import torch
from PIL import Image
from transformers import XLMRobertaTokenizerFast
from diffusers import DDIMScheduler, KandinskyInpaintPipeline, KandinskyPriorPipeline, UNet2DConditionModel, VQModel
from diffusers.pipelines.kandinsky.text_encoder import MCLIPConfig, MultilingualCLIP
+from diffusers.utils import is_transformers_version
from ...testing_utils import (
backend_empty_cache,
@@ -231,6 +233,11 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
dummies = Dummies()
return dummies.get_dummy_inputs(device=device, seed=seed)
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.56.2"),
+ reason="Latest transformers changes the slices",
+ strict=False,
+ )
def test_kandinsky_inpaint(self):
device = "cpu"
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 4054e38c56..8f8e58a8c4 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -29,6 +29,7 @@ from diffusers import (
)
from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -290,4 +291,11 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
assert image.shape == (512, 512, 3)
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
- assert max_diff < 2e-4
+ expected_max_diffs = Expectations(
+ {
+ ("xpu", 3): 2e-3,
+ ("cuda", 7): 2e-4,
+ }
+ )
+ expected_max_diff = expected_max_diffs.get_expectation()
+ assert max_diff < expected_max_diff
diff --git a/tests/pipelines/kandinsky5/__init__.py b/tests/pipelines/kandinsky5/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/kandinsky5/test_kandinsky5.py b/tests/pipelines/kandinsky5/test_kandinsky5.py
new file mode 100644
index 0000000000..47fccb632a
--- /dev/null
+++ b/tests/pipelines/kandinsky5/test_kandinsky5.py
@@ -0,0 +1,306 @@
+# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+from transformers import (
+ CLIPTextConfig,
+ CLIPTextModel,
+ CLIPTokenizer,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+ Qwen2VLProcessor,
+)
+
+from diffusers import (
+ AutoencoderKLHunyuanVideo,
+ FlowMatchEulerDiscreteScheduler,
+ Kandinsky5T2VPipeline,
+ Kandinsky5Transformer3DModel,
+)
+
+from ...testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class Kandinsky5T2VPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = Kandinsky5T2VPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+
+ # Define required optional parameters for your pipeline
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ "max_sequence_length",
+ ]
+ )
+
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLHunyuanVideo(
+ in_channels=3,
+ out_channels=3,
+ spatial_compression_ratio=8,
+ temporal_compression_ratio=4,
+ latent_channels=4,
+ block_out_channels=(8, 8, 8, 8),
+ layers_per_block=1,
+ norm_num_groups=4,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler(shift=7.0)
+
+ # Dummy Qwen2.5-VL model
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2VLProcessor.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
+
+ # Dummy CLIP model
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
+ tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+ torch.manual_seed(0)
+ transformer = Kandinsky5Transformer3DModel(
+ in_visual_dim=4,
+ in_text_dim=16, # Match tiny Qwen2.5-VL hidden size
+ in_text_dim2=32, # Match tiny CLIP hidden size
+ time_dim=32,
+ out_visual_dim=4,
+ patch_size=(1, 2, 2),
+ model_dim=48,
+ ff_dim=128,
+ num_text_blocks=1,
+ num_visual_blocks=1,
+ axes_dims=(8, 8, 8),
+ visual_cond=False,
+ )
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder.eval(),
+ "tokenizer": tokenizer,
+ "text_encoder_2": text_encoder_2.eval(),
+ "tokenizer_2": tokenizer_2,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "A cat dancing",
+ "negative_prompt": "blurry, low quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "num_frames": 5,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+
+ # Check video shape: (batch, frames, channel, height, width)
+ expected_shape = (1, 5, 3, 32, 32)
+ self.assertEqual(video.shape, expected_shape)
+
+ # Check specific values
+ expected_slice = torch.tensor(
+ [
+ 0.4330,
+ 0.4254,
+ 0.4285,
+ 0.3835,
+ 0.4253,
+ 0.4196,
+ 0.3704,
+ 0.3714,
+ 0.4999,
+ 0.5346,
+ 0.4795,
+ 0.4637,
+ 0.4930,
+ 0.5124,
+ 0.4902,
+ 0.4570,
+ ]
+ )
+
+ generated_slice = video.flatten()
+ # Take first 8 and last 8 values for comparison
+ video_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(video_slice, expected_slice, atol=1e-3),
+ f"video_slice: {video_slice}, expected_slice: {expected_slice}",
+ )
+
+ def test_inference_batch_single_identical(self):
+ # Override to test batch single identical with video
+ super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=1e-2)
+
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-3, rtol=1e-3):
+ components = self.get_dummy_components()
+
+ text_component_names = ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]
+ text_components = {k: (v if k in text_component_names else None) for k, v in components.items()}
+ non_text_components = {k: (v if k not in text_component_names else None) for k, v in components.items()}
+
+ pipe_with_just_text_encoder = self.pipeline_class(**text_components)
+ pipe_with_just_text_encoder = pipe_with_just_text_encoder.to(torch_device)
+
+ pipe_without_text_encoders = self.pipeline_class(**non_text_components)
+ pipe_without_text_encoders = pipe_without_text_encoders.to(torch_device)
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+
+ # Compute `encode_prompt()`.
+
+ # Test single prompt
+ prompt = "A cat dancing"
+ with torch.no_grad():
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
+ prompt, device=torch_device, max_sequence_length=16
+ )
+
+ # Check shapes
+ self.assertEqual(prompt_embeds_qwen.shape, (1, 4, 16)) # [batch, seq_len, embed_dim]
+ self.assertEqual(prompt_embeds_clip.shape, (1, 32)) # [batch, embed_dim]
+ self.assertEqual(prompt_cu_seqlens.shape, (2,)) # [batch + 1]
+
+ # Test batch of prompts
+ prompts = ["A cat dancing", "A dog running"]
+ with torch.no_grad():
+ batch_embeds_qwen, batch_embeds_clip, batch_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
+ prompts, device=torch_device, max_sequence_length=16
+ )
+
+ # Check batch size
+ self.assertEqual(batch_embeds_qwen.shape, (len(prompts), 4, 16))
+ self.assertEqual(batch_embeds_clip.shape, (len(prompts), 32))
+ self.assertEqual(len(batch_cu_seqlens), len(prompts) + 1) # [0, len1, len1+len2]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["guidance_scale"] = 1.0
+
+ # baseline output: full pipeline
+ pipe_out = pipe(**inputs).frames
+
+ # test against pipeline call with pre-computed prompt embeds
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["guidance_scale"] = 1.0
+
+ with torch.no_grad():
+ prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe_with_just_text_encoder.encode_prompt(
+ inputs["prompt"], device=torch_device, max_sequence_length=inputs["max_sequence_length"]
+ )
+
+ inputs["prompt"] = None
+ inputs["prompt_embeds_qwen"] = prompt_embeds_qwen
+ inputs["prompt_embeds_clip"] = prompt_embeds_clip
+ inputs["prompt_cu_seqlens"] = prompt_cu_seqlens
+
+ pipe_out_2 = pipe_without_text_encoders(**inputs)[0]
+
+ self.assertTrue(
+ torch.allclose(pipe_out, pipe_out_2, atol=atol, rtol=rtol),
+ f"max diff: {torch.max(torch.abs(pipe_out - pipe_out_2))}",
+ )
+
+ @unittest.skip("Kandinsky5T2VPipeline does not support attention slicing")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("Kandinsky5T2VPipeline does not support xformers")
+ def test_xformers_attention_forwardGenerator_pass(self):
+ pass
+
+ @unittest.skip("Kandinsky5T2VPipeline does not support VAE slicing")
+ def test_vae_slicing(self):
+ pass
diff --git a/tests/pipelines/marigold/test_marigold_depth.py b/tests/pipelines/marigold/test_marigold_depth.py
index 3e8ccbf5c0..3c85305992 100644
--- a/tests/pipelines/marigold/test_marigold_depth.py
+++ b/tests/pipelines/marigold/test_marigold_depth.py
@@ -33,6 +33,7 @@ from diffusers import (
)
from ...testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -356,7 +357,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f32_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f32_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=False,
device=torch_device,
@@ -369,7 +370,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -382,7 +383,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G2024_S1_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G2024_S1_P768_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -395,12 +396,23 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S2_P768_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S2_P768_E1_B1_M1(self):
+ # fmt: off
+ expected_slices = Expectations(
+ {
+ ("cuda", 7): np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+ ("xpu", 3): np.array([0.1084, 0.1096, 0.1108, 0.1080, 0.1083, 0.1080,
+ 0.1085, 0.1057, 0.0996]),
+ }
+ )
+ expected_slice = expected_slices.get_expectation()
+ # fmt: on
+
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
generator_seed=0,
- expected_slice=np.array([0.1085, 0.1098, 0.1110, 0.1081, 0.1085, 0.1082, 0.1085, 0.1057, 0.0996]),
+ expected_slice=expected_slice,
num_inference_steps=2,
processing_resolution=768,
ensemble_size=1,
@@ -408,7 +420,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -421,7 +433,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E3_B1_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E3_B1_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -435,7 +447,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P768_E4_B2_M1(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P768_E4_B2_M1(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
@@ -449,7 +461,7 @@ class MarigoldDepthPipelineIntegrationTests(unittest.TestCase):
match_input_resolution=True,
)
- def test_marigold_depth_einstein_f16_cuda_G0_S1_P512_E1_B1_M0(self):
+ def test_marigold_depth_einstein_f16_accelerator_G0_S1_P512_E1_B1_M0(self):
self._test_marigold_depth(
is_fp16=True,
device=torch_device,
diff --git a/tests/pipelines/omnigen/test_pipeline_omnigen.py b/tests/pipelines/omnigen/test_pipeline_omnigen.py
index 28648aa76f..1a758b7050 100644
--- a/tests/pipelines/omnigen/test_pipeline_omnigen.py
+++ b/tests/pipelines/omnigen/test_pipeline_omnigen.py
@@ -22,7 +22,7 @@ class OmniGenPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = OmniGenPipeline
params = frozenset(["prompt", "guidance_scale"])
batch_params = frozenset(["prompt"])
-
+ test_xformers_attention = False
test_layerwise_casting = True
def get_dummy_components(self):
diff --git a/tests/pipelines/pag/test_pag_sd_inpaint.py b/tests/pipelines/pag/test_pag_sd_inpaint.py
index 709df68370..754158bbf1 100644
--- a/tests/pipelines/pag/test_pag_sd_inpaint.py
+++ b/tests/pipelines/pag/test_pag_sd_inpaint.py
@@ -255,7 +255,7 @@ class StableDiffusionPAGInpaintPipelineFastTests(
@require_torch_accelerator
class StableDiffusionPAGPipelineIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusionPAGInpaintPipeline
- repo_id = "runwayml/stable-diffusion-v1-5"
+ repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
def setUp(self):
super().setUp()
diff --git a/tests/pipelines/prx/__init__.py b/tests/pipelines/prx/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/pipelines/prx/test_pipeline_prx.py b/tests/pipelines/prx/test_pipeline_prx.py
new file mode 100644
index 0000000000..46c6a5760e
--- /dev/null
+++ b/tests/pipelines/prx/test_pipeline_prx.py
@@ -0,0 +1,265 @@
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from transformers import AutoTokenizer
+from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+
+from diffusers.models import AutoencoderDC, AutoencoderKL
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import is_transformers_version
+
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+@pytest.mark.xfail(
+ condition=is_transformers_version(">", "4.57.1"),
+ reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
+ strict=False,
+)
+class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = PRXPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ @classmethod
+ def setUpClass(cls):
+ # Ensure PRXPipeline has an _execution_device property expected by __call__
+ if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
+ try:
+ setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
+ except Exception:
+ pass
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = PRXTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ context_in_dim=8,
+ hidden_size=8,
+ mlp_ratio=2.0,
+ num_heads=2,
+ depth=1,
+ axes_dim=[2, 2],
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0,
+ scaling_factor=1.0,
+ ).eval()
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+ tokenizer.model_max_length = 64
+
+ torch.manual_seed(0)
+
+ encoder_params = {
+ "vocab_size": tokenizer.vocab_size,
+ "hidden_size": 8,
+ "intermediate_size": 16,
+ "num_hidden_layers": 1,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 1,
+ "head_dim": 4,
+ "max_position_embeddings": 64,
+ "layer_types": ["full_attention"],
+ "attention_bias": False,
+ "attention_dropout": 0.0,
+ "dropout_rate": 0.0,
+ "hidden_activation": "gelu_pytorch_tanh",
+ "rms_norm_eps": 1e-06,
+ "attn_logit_softcapping": 50.0,
+ "final_logit_softcapping": 30.0,
+ "query_pre_attn_scalar": 4,
+ "rope_theta": 10000.0,
+ "sliding_window": 4096,
+ }
+ encoder_config = T5GemmaModuleConfig(**encoder_params)
+ text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
+ text_encoder = T5GemmaEncoder(text_encoder_config)
+
+ return {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ return {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "pt",
+ "use_resolution_binning": False,
+ }
+
+ def test_inference(self):
+ device = "cpu"
+ components = self.get_dummy_components()
+ pipe = PRXPipeline(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ try:
+ pipe.register_to_config(_execution_device="cpu")
+ except Exception:
+ pass
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.zeros(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_inputs(self):
+ components = self.get_dummy_components()
+ pipe = PRXPipeline(**components)
+ pipe = pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+ try:
+ pipe.register_to_config(_execution_device="cpu")
+ except Exception:
+ pass
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ def callback_inputs_subset(pipe, i, t, callback_kwargs):
+ for tensor_name in callback_kwargs.keys():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ def callback_inputs_all(pipe, i, t, callback_kwargs):
+ for tensor_name in pipe._callback_tensor_inputs:
+ assert tensor_name in callback_kwargs
+ for tensor_name in callback_kwargs.keys():
+ assert tensor_name in pipe._callback_tensor_inputs
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs("cpu")
+
+ inputs["callback_on_step_end"] = callback_inputs_subset
+ inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
+ _ = pipe(**inputs)[0]
+
+ inputs["callback_on_step_end"] = callback_inputs_all
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ def to_np_local(tensor):
+ if isinstance(tensor, torch.Tensor):
+ return tensor.detach().cpu().numpy()
+ return tensor
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
+ self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
+
+ def test_inference_with_autoencoder_dc(self):
+ """Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ torch.manual_seed(0)
+ vae_dc = AutoencoderDC(
+ in_channels=3,
+ latent_channels=4,
+ attention_head_dim=2,
+ encoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ decoder_block_types=(
+ "ResBlock",
+ "EfficientViTBlock",
+ ),
+ encoder_block_out_channels=(8, 8),
+ decoder_block_out_channels=(8, 8),
+ encoder_qkv_multiscales=((), (5,)),
+ decoder_qkv_multiscales=((), (5,)),
+ encoder_layers_per_block=(1, 1),
+ decoder_layers_per_block=(1, 1),
+ upsample_block_type="interpolate",
+ downsample_block_type="stride_conv",
+ decoder_norm_types="rms_norm",
+ decoder_act_fns="silu",
+ ).eval()
+
+ components["vae"] = vae_dc
+
+ pipe = PRXPipeline(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ expected_scale_factor = vae_dc.spatial_compression_ratio
+ self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs)[0]
+ generated_image = image[0]
+
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+ expected_image = torch.zeros(3, 32, 32)
+ max_diff = np.abs(generated_image - expected_image).max()
+ self.assertLessEqual(max_diff, 1e10)
diff --git a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
index c78e5cb233..188106b49b 100644
--- a/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
+++ b/tests/pipelines/qwenimage/test_qwenimage_controlnet.py
@@ -44,7 +44,6 @@ class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
batch_params = frozenset(["prompt", "negative_prompt", "control_image"])
image_params = frozenset(["control_image"])
image_latents_params = frozenset(["latents"])
-
required_optional_params = frozenset(
[
"num_inference_steps",
@@ -59,7 +58,7 @@ class QwenControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
supports_dduf = False
- test_xformers_attention = True
+ test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
diff --git a/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
new file mode 100644
index 0000000000..6faf347282
--- /dev/null
+++ b/tests/pipelines/qwenimage/test_qwenimage_edit_plus.py
@@ -0,0 +1,253 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import pytest
+import torch
+from PIL import Image
+from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, Qwen2VLProcessor
+
+from diffusers import (
+ AutoencoderKLQwenImage,
+ FlowMatchEulerDiscreteScheduler,
+ QwenImageEditPlusPipeline,
+ QwenImageTransformer2DModel,
+)
+
+from ...testing_utils import enable_full_determinism, torch_device
+from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin, to_np
+
+
+enable_full_determinism()
+
+
+class QwenImageEditPlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = QwenImageEditPlusPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["prompt", "image"])
+ image_params = frozenset(["image"])
+ image_latents_params = frozenset(["latents"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ def get_dummy_components(self):
+ tiny_ckpt_id = "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
+
+ torch.manual_seed(0)
+ transformer = QwenImageTransformer2DModel(
+ patch_size=2,
+ in_channels=16,
+ out_channels=4,
+ num_layers=2,
+ attention_head_dim=16,
+ num_attention_heads=3,
+ joint_attention_dim=16,
+ guidance_embeds=False,
+ axes_dims_rope=(8, 4, 4),
+ )
+
+ torch.manual_seed(0)
+ z_dim = 4
+ vae = AutoencoderKLQwenImage(
+ base_dim=z_dim * 6,
+ z_dim=z_dim,
+ dim_mult=[1, 2, 4],
+ num_res_blocks=1,
+ temperal_downsample=[False, True],
+ latents_mean=[0.0] * z_dim,
+ latents_std=[1.0] * z_dim,
+ )
+
+ torch.manual_seed(0)
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ torch.manual_seed(0)
+ config = Qwen2_5_VLConfig(
+ text_config={
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_hidden_layers": 2,
+ "num_attention_heads": 2,
+ "num_key_value_heads": 2,
+ "rope_scaling": {
+ "mrope_section": [1, 1, 2],
+ "rope_type": "default",
+ "type": "default",
+ },
+ "rope_theta": 1000000.0,
+ },
+ vision_config={
+ "depth": 2,
+ "hidden_size": 16,
+ "intermediate_size": 16,
+ "num_heads": 2,
+ "out_hidden_size": 16,
+ },
+ hidden_size=16,
+ vocab_size=152064,
+ vision_end_token_id=151653,
+ vision_start_token_id=151652,
+ vision_token_id=151654,
+ )
+ text_encoder = Qwen2_5_VLForConditionalGeneration(config)
+ tokenizer = Qwen2Tokenizer.from_pretrained(tiny_ckpt_id)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "processor": Qwen2VLProcessor.from_pretrained(tiny_ckpt_id),
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ image = Image.new("RGB", (32, 32))
+ inputs = {
+ "prompt": "dance monkey",
+ "image": [image, image],
+ "negative_prompt": "bad quality",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "true_cfg_scale": 1.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image = pipe(**inputs).images
+ generated_image = image[0]
+ self.assertEqual(generated_image.shape, (3, 32, 32))
+
+ # fmt: off
+ expected_slice = torch.tensor([[0.5637, 0.6341, 0.6001, 0.5620, 0.5794, 0.5498, 0.5757, 0.6389, 0.4174, 0.3597, 0.5649, 0.4894, 0.4969, 0.5255, 0.4083, 0.4986]])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ def test_attention_slicing_forward_pass(
+ self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
+ ):
+ if not self.test_attention_slicing:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ output_without_slicing = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=1)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing1 = pipe(**inputs)[0]
+
+ pipe.enable_attention_slicing(slice_size=2)
+ inputs = self.get_dummy_inputs(generator_device)
+ output_with_slicing2 = pipe(**inputs)[0]
+
+ if test_max_difference:
+ max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
+ max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
+ self.assertLess(
+ max(max_diff1, max_diff2),
+ expected_max_diff,
+ "Attention slicing should not affect the inference results",
+ )
+
+ def test_vae_tiling(self, expected_diff_max: float = 0.2):
+ generator_device = "cpu"
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe.to("cpu")
+ pipe.set_progress_bar_config(disable=None)
+
+ # Without tiling
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_without_tiling = pipe(**inputs)[0]
+
+ # With tiling
+ pipe.vae.enable_tiling(
+ tile_sample_min_height=96,
+ tile_sample_min_width=96,
+ tile_sample_stride_height=64,
+ tile_sample_stride_width=64,
+ )
+ inputs = self.get_dummy_inputs(generator_device)
+ inputs["height"] = inputs["width"] = 128
+ output_with_tiling = pipe(**inputs)[0]
+
+ self.assertLess(
+ (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
+ expected_diff_max,
+ "VAE tiling should not affect the inference results",
+ )
+
+ @pytest.mark.xfail(condition=True, reason="Preconfigured embeddings need to be revisited.", strict=True)
+ def test_encode_prompt_works_in_isolation(self, extra_required_param_value_dict=None, atol=1e-4, rtol=1e-4):
+ super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict, atol, rtol)
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_num_images_per_prompt():
+ super().test_num_images_per_prompt()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_consistent():
+ super().test_inference_batch_consistent()
+
+ @pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
+ def test_inference_batch_single_identical():
+ super().test_inference_batch_single_identical()
diff --git a/tests/pipelines/sana/test_sana.py b/tests/pipelines/sana/test_sana.py
index 34ea3079b1..f23303c966 100644
--- a/tests/pipelines/sana/test_sana.py
+++ b/tests/pipelines/sana/test_sana.py
@@ -23,6 +23,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel
from ...testing_utils import (
+ IS_GITHUB_ACTIONS,
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
@@ -304,6 +305,10 @@ class SanaPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
+
@slow
@require_torch_accelerator
diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py
index 043e276fcb..df14d935ed 100644
--- a/tests/pipelines/sana/test_sana_controlnet.py
+++ b/tests/pipelines/sana/test_sana_controlnet.py
@@ -28,10 +28,7 @@ from diffusers import (
)
from diffusers.utils.torch_utils import randn_tensor
-from ...testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -326,3 +323,7 @@ class SanaControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/sana/test_sana_sprint.py b/tests/pipelines/sana/test_sana_sprint.py
index fee2304dce..0d45205ea8 100644
--- a/tests/pipelines/sana/test_sana_sprint.py
+++ b/tests/pipelines/sana/test_sana_sprint.py
@@ -21,10 +21,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
-from ...testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -300,3 +297,7 @@ class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py
index c218abb8e9..5de5c7f446 100644
--- a/tests/pipelines/sana/test_sana_sprint_img2img.py
+++ b/tests/pipelines/sana/test_sana_sprint_img2img.py
@@ -22,10 +22,7 @@ from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
from diffusers.utils.torch_utils import randn_tensor
-from ...testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
@@ -312,3 +309,7 @@ class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
def test_float16_inference(self):
# Requires higher tolerance as model seems very sensitive to dtype
super().test_float16_inference(expected_max_diff=0.08)
+
+ @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+ def test_layerwise_casting_inference(self):
+ super().test_layerwise_casting_inference()
diff --git a/tests/pipelines/sana/test_sana_video.py b/tests/pipelines/sana/test_sana_video.py
new file mode 100644
index 0000000000..9f360a942a
--- /dev/null
+++ b/tests/pipelines/sana/test_sana_video.py
@@ -0,0 +1,225 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
+
+from diffusers import AutoencoderKLWan, DPMSolverMultistepScheduler, SanaVideoPipeline, SanaVideoTransformer3DModel
+
+from ...testing_utils import (
+ backend_empty_cache,
+ enable_full_determinism,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SanaVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SanaVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = DPMSolverMultistepScheduler()
+
+ torch.manual_seed(0)
+ text_encoder_config = Gemma2Config(
+ head_dim=16,
+ hidden_size=8,
+ initializer_range=0.02,
+ intermediate_size=64,
+ max_position_embeddings=8192,
+ model_type="gemma2",
+ num_attention_heads=2,
+ num_hidden_layers=1,
+ num_key_value_heads=2,
+ vocab_size=8,
+ attn_implementation="eager",
+ )
+ text_encoder = Gemma2Model(text_encoder_config)
+ tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ transformer = SanaVideoTransformer3DModel(
+ in_channels=16,
+ out_channels=16,
+ num_attention_heads=2,
+ attention_head_dim=12,
+ num_layers=2,
+ num_cross_attention_heads=2,
+ cross_attention_head_dim=12,
+ cross_attention_dim=24,
+ caption_channels=8,
+ mlp_ratio=2.5,
+ dropout=0.0,
+ attention_bias=False,
+ sample_size=8,
+ patch_size=(1, 2, 2),
+ norm_elementwise_affine=False,
+ norm_eps=1e-6,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "",
+ "negative_prompt": "",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 32,
+ "width": 32,
+ "frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "complex_human_instruction": [],
+ "use_resolution_binning": False,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 32, 32))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ def test_save_load_local(self, expected_max_difference=5e-4):
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ self.assertLess(max_diff, expected_max_difference)
+
+ # TODO(aryan): Create a dummy gemma model with smol vocab size
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_consistent(self):
+ pass
+
+ @unittest.skip(
+ "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to a embedding lookup error. This test uses a long prompt that causes the error."
+ )
+ def test_inference_batch_single_identical(self):
+ pass
+
+ def test_float16_inference(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_float16_inference(expected_max_diff=0.08)
+
+ def test_save_load_float16(self):
+ # Requires higher tolerance as model seems very sensitive to dtype
+ super().test_save_load_float16(expected_max_diff=0.2)
+
+
+@slow
+@require_torch_accelerator
+class SanaVideoPipelineIntegrationTests(unittest.TestCase):
+ prompt = "Evening, backlight, side lighting, soft light, high contrast, mid-shot, centered composition, clean solo shot, warm color. A young Caucasian man stands in a forest."
+
+ def setUp(self):
+ super().setUp()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def tearDown(self):
+ super().tearDown()
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ @unittest.skip("TODO: test needs to be implemented")
+ def test_sana_video_480p(self):
+ pass
diff --git a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
index f8267186db..0bc821b7e6 100644
--- a/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
+++ b/tests/pipelines/stable_cascade/test_stable_cascade_prior.py
@@ -17,11 +17,13 @@ import gc
import unittest
import numpy as np
+import pytest
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.models import StableCascadeUNet
+from diffusers.utils import is_transformers_version
from diffusers.utils.import_utils import is_peft_available
from ...testing_utils import (
@@ -154,6 +156,11 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
}
return inputs
+ @pytest.mark.xfail(
+ condition=is_transformers_version(">=", "4.57.1"),
+ reason="Test fails with the latest transformers version",
+ strict=False,
+ )
def test_wuerstchen_prior(self):
device = "cpu"
diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py
index 09df140f1a..a17db3ff0c 100644
--- a/tests/pipelines/test_pipelines.py
+++ b/tests/pipelines/test_pipelines.py
@@ -28,14 +28,15 @@ import warnings
import numpy as np
import PIL.Image
+import pytest
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
+from huggingface_hub.utils import HfHubHTTPError
from parameterized import parameterized
from PIL import Image
-from requests.exceptions import HTTPError
from transformers import CLIPImageProcessor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import (
@@ -62,10 +63,7 @@ from diffusers import (
)
from diffusers.pipelines.pipeline_utils import _get_pipeline_class
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
-from diffusers.utils import (
- CONFIG_NAME,
- WEIGHTS_NAME,
-)
+from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, is_transformers_version
from diffusers.utils.torch_utils import is_compiled_module
from ..testing_utils import (
@@ -430,7 +428,7 @@ class DownloadTests(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# Download this model to make sure it's in the cache.
@@ -457,7 +455,7 @@ class DownloadTests(unittest.TestCase):
response_mock = mock.Mock()
response_mock.status_code = 500
response_mock.headers = {}
- response_mock.raise_for_status.side_effect = HTTPError
+ response_mock.raise_for_status.side_effect = HfHubHTTPError("Server down", response=mock.Mock())
response_mock.json.return_value = {}
# first check that with local files only the pipeline can only be used if cached
@@ -584,6 +582,7 @@ class DownloadTests(unittest.TestCase):
assert not any(f.endswith(unexpected_ext) for f in files)
assert all(variant in f for f in model_files if f.endswith(model_ext) and variant is not None)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_legacy_variants_with_sharded_ckpts_raises_warning(self):
repo_id = "hf-internal-testing/tiny-stable-diffusion-pipe-variants-all-kinds"
logger = logging.get_logger("diffusers.pipelines.pipeline_utils")
@@ -630,6 +629,7 @@ class DownloadTests(unittest.TestCase):
# https://huggingface.co/hf-internal-testing/stable-diffusion-broken-variants/tree/main/unet
assert len(files) == 15, f"We should only download 15 files, not {len(files)}"
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_bin_only_variant_exists_for_model(self):
variant = None
use_safetensors = False
@@ -675,6 +675,7 @@ class DownloadTests(unittest.TestCase):
assert "Could not find the necessary `safetensors` weights" in str(error_context.exception)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_download_bin_variant_does_not_exist_for_model(self):
variant = "no_ema"
use_safetensors = False
@@ -690,6 +691,7 @@ class DownloadTests(unittest.TestCase):
)
assert "Error no file name" in str(error_context.exception)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_local_save_load_index(self):
prompt = "hello"
for variant in [None, "fp16"]:
@@ -1584,6 +1586,7 @@ class PipelineFastTests(unittest.TestCase):
assert pipeline.scheduler is not None
assert pipeline.feature_extractor is not None
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_no_pytorch_download_when_doing_safetensors(self):
# by default we don't download
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -1603,6 +1606,7 @@ class PipelineFastTests(unittest.TestCase):
# pytorch does not
assert not os.path.exists(os.path.join(path, "diffusion_pytorch_model.bin"))
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_no_safetensors_download_when_doing_pytorch(self):
use_safetensors = False
@@ -1888,6 +1892,7 @@ class PipelineFastTests(unittest.TestCase):
"DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True
)
+ @pytest.mark.xfail(condition=is_transformers_version(">", "4.56.2"), reason="Some import error", strict=False)
def test_wrong_model(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
with self.assertRaises(ValueError) as error_context:
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index db8209835b..2af4ad0314 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -1461,6 +1461,8 @@ class PipelineTesterMixin:
def test_save_load_optional_components(self, expected_max_difference=1e-4):
if not hasattr(self.pipeline_class, "_optional_components"):
return
+ if not self.pipeline_class._optional_components:
+ return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
diff --git a/tests/pipelines/wan/test_wan_vace.py b/tests/pipelines/wan/test_wan_vace.py
index f99863c880..fe078c0deb 100644
--- a/tests/pipelines/wan/test_wan_vace.py
+++ b/tests/pipelines/wan/test_wan_vace.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import tempfile
import unittest
import numpy as np
@@ -19,9 +20,15 @@ import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
-from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers import (
+ AutoencoderKLWan,
+ FlowMatchEulerDiscreteScheduler,
+ UniPCMultistepScheduler,
+ WanVACEPipeline,
+ WanVACETransformer3DModel,
+)
-from ...testing_utils import enable_full_determinism
+from ...testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@@ -212,3 +219,81 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
)
def test_save_load_float16(self):
pass
+
+ def test_inference_with_only_transformer(self):
+ components = self.get_dummy_components()
+ components["transformer_2"] = None
+ components["boundary_ratio"] = 0.0
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ video = pipe(**inputs).frames[0]
+ assert video.shape == (17, 3, 16, 16)
+
+ def test_inference_with_only_transformer_2(self):
+ components = self.get_dummy_components()
+ components["transformer_2"] = components["transformer"]
+ components["transformer"] = None
+
+ # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+ # because starting timestep t == 1000 == boundary_timestep
+ components["scheduler"] = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+ )
+
+ components["boundary_ratio"] = 1.0
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ video = pipe(**inputs).frames[0]
+ assert video.shape == (17, 3, 16, 16)
+
+ def test_save_load_optional_components(self, expected_max_difference=1e-4):
+ optional_component = ["transformer"]
+
+ components = self.get_dummy_components()
+ components["transformer_2"] = components["transformer"]
+ # FlowMatchEulerDiscreteScheduler doesn't support running low noise only scheduler
+ # because starting timestep t == 1000 == boundary_timestep
+ components["scheduler"] = UniPCMultistepScheduler(
+ prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0
+ )
+ for component in optional_component:
+ components[component] = None
+
+ components["boundary_ratio"] = 1.0
+
+ pipe = self.pipeline_class(**components)
+ for component in pipe.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ generator_device = "cpu"
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output = pipe(**inputs)[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ pipe.save_pretrained(tmpdir, safe_serialization=False)
+ pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+ for component in pipe_loaded.components.values():
+ if hasattr(component, "set_default_attn_processor"):
+ component.set_default_attn_processor()
+ pipe_loaded.to(torch_device)
+ pipe_loaded.set_progress_bar_config(disable=None)
+
+ for component in optional_component:
+ assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
+
+ inputs = self.get_dummy_inputs(generator_device)
+ torch.manual_seed(0)
+ output_loaded = pipe_loaded(**inputs)[0]
+
+ max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+ assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 38322459e7..0f4fd408a7 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -360,33 +360,33 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
{
("xpu", 3): np.array(
[
- 0.1953125,
- 0.3125,
- 0.31445312,
- 0.13085938,
- 0.30664062,
- 0.29296875,
- 0.11523438,
- 0.2890625,
+ 0.16796875,
+ 0.27929688,
0.28320312,
- 0.16601562,
- 0.3046875,
- 0.328125,
- 0.140625,
- 0.31640625,
+ 0.11328125,
+ 0.27539062,
+ 0.26171875,
+ 0.10742188,
+ 0.26367188,
+ 0.26171875,
+ 0.1484375,
+ 0.2734375,
+ 0.296875,
+ 0.13476562,
+ 0.2890625,
+ 0.30078125,
+ 0.1171875,
+ 0.28125,
+ 0.28125,
+ 0.16015625,
+ 0.31445312,
+ 0.30078125,
+ 0.15625,
0.32421875,
- 0.12304688,
- 0.3046875,
- 0.3046875,
- 0.17578125,
- 0.3359375,
- 0.3203125,
- 0.16601562,
- 0.34375,
- 0.31640625,
- 0.15429688,
- 0.328125,
- 0.31054688,
+ 0.296875,
+ 0.14453125,
+ 0.30859375,
+ 0.2890625,
]
),
("cuda", 7): np.array(
diff --git a/tests/quantization/torchao/README.md b/tests/quantization/torchao/README.md
index fadc529e12..373593091a 100644
--- a/tests/quantization/torchao/README.md
+++ b/tests/quantization/torchao/README.md
@@ -29,7 +29,7 @@ The benchmark results for Flux and CogVideoX can be found in [this](https://gith
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
```bash
-HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
+HF_XET_HIGH_PERFORMANCE=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
```
`diffusers-cli`:
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 920c3a55f5..38997de17b 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -14,11 +14,13 @@
# limitations under the License.
import gc
+import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
+from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -65,6 +67,9 @@ if is_torchao_available():
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import get_model_size_in_bytes
+ if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
+ from torchao.quantization import Int8WeightOnlyConfig
+
@require_torch
@require_torch_accelerator
@@ -522,6 +527,15 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
+ components = self.get_dummy_components(quantization_config)
+ pipe = FluxPipeline(**components).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ _ = pipe(**inputs)
+
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@@ -628,6 +642,14 @@ class TorchAoSerializationTest(unittest.TestCase):
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+ @require_torchao_version_greater_or_equal("0.9.0")
+ def test_aobase_config(self):
+ quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
+ expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
+ device = torch_device
+ self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
+ self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+
@require_torchao_version_greater_or_equal("0.7.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
diff --git a/tests/single_file/single_file_testing_utils.py b/tests/single_file/single_file_testing_utils.py
index 3510d3371c..52fd2f5bfc 100644
--- a/tests/single_file/single_file_testing_utils.py
+++ b/tests/single_file/single_file_testing_utils.py
@@ -1,3 +1,4 @@
+import gc
import tempfile
from io import BytesIO
@@ -9,7 +10,10 @@ from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_nam
from diffusers.models.attention_processor import AttnProcessor
from ..testing_utils import (
+ backend_empty_cache,
+ nightly,
numpy_cosine_similarity_distance,
+ require_torch_accelerator,
torch_device,
)
@@ -47,6 +51,93 @@ def download_diffusers_config(repo_id, tmpdir):
return path
+@nightly
+@require_torch_accelerator
+class SingleFileModelTesterMixin:
+ def setup_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def teardown_method(self):
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+ def test_single_file_model_config(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
+ for param_name, param_value in model_single_file.config.items():
+ if param_name in PARAMS_TO_IGNORE:
+ continue
+ assert model.config[param_name] == param_value, (
+ f"{param_name} differs between pretrained loading and single file loading"
+ )
+
+ def test_single_file_model_parameters(self):
+ pretrained_kwargs = {}
+ single_file_kwargs = {}
+
+ if hasattr(self, "subfolder") and self.subfolder:
+ pretrained_kwargs["subfolder"] = self.subfolder
+
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ pretrained_kwargs["torch_dtype"] = self.torch_dtype
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
+ model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)
+
+ state_dict = model.state_dict()
+ state_dict_single_file = model_single_file.state_dict()
+
+ assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
+ "Model parameters keys differ between pretrained and single file loading"
+ )
+
+ for key in state_dict.keys():
+ param = state_dict[key]
+ param_single_file = state_dict_single_file[key]
+
+ assert param.shape == param_single_file.shape, (
+ f"Parameter shape mismatch for {key}: "
+ f"pretrained {param.shape} vs single file {param_single_file.shape}"
+ )
+
+ assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
+ f"Parameter values differ for {key}: "
+ f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
+ )
+
+ def test_checkpoint_altered_keys_loading(self):
+ # Test loading with checkpoints that have altered keys
+ if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
+ return
+
+ for ckpt_path in self.alternate_keys_ckpt_paths:
+ backend_empty_cache(torch_device)
+
+ single_file_kwargs = {}
+ if hasattr(self, "torch_dtype") and self.torch_dtype:
+ single_file_kwargs["torch_dtype"] = self.torch_dtype
+
+ model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)
+
+ del model
+ gc.collect()
+ backend_empty_cache(torch_device)
+
+
class SDSingleFileTesterMixin:
single_file_kwargs = {}
diff --git a/tests/single_file/test_lumina2_transformer.py b/tests/single_file/test_lumina2_transformer.py
index 99d9b71395..bb5a0bf473 100644
--- a/tests/single_file/test_lumina2_transformer.py
+++ b/tests/single_file/test_lumina2_transformer.py
@@ -13,26 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
from diffusers import (
Lumina2Transformer2DModel,
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
+class TestLumina2Transformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = Lumina2Transformer2DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
alternate_keys_ckpt_paths = [
@@ -40,34 +35,4 @@ class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Alpha-VLLM/Lumina-Image-2.0"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
+ subfolder = "transformer"
diff --git a/tests/single_file/test_model_autoencoder_dc_single_file.py b/tests/single_file/test_model_autoencoder_dc_single_file.py
index 5195f8e52f..444ca40469 100644
--- a/tests/single_file/test_model_autoencoder_dc_single_file.py
+++ b/tests/single_file/test_model_autoencoder_dc_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
@@ -23,38 +21,24 @@ from diffusers import (
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderDCSingleFileTests(unittest.TestCase):
+class TestAutoencoderDCSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderDC
ckpt_path = "https://huggingface.co/mit-han-lab/dc-ae-f32c32-sana-1.0/blob/main/model.safetensors"
repo_id = "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers"
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -80,18 +64,6 @@ class AutoencoderDCSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
-
def test_single_file_in_type_variant_components(self):
# `in` variant checkpoints require passing in a `config` parameter
# in order to set the scaling factor correctly.
diff --git a/tests/single_file/test_model_controlnet_single_file.py b/tests/single_file/test_model_controlnet_single_file.py
index e5214fe3f2..2fa81fe3ae 100644
--- a/tests/single_file/test_model_controlnet_single_file.py
+++ b/tests/single_file/test_model_controlnet_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
@@ -23,46 +21,19 @@ from diffusers import (
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class ControlNetModelSingleFileTests(unittest.TestCase):
+class TestControlNetModelSingleFile(SingleFileModelTesterMixin):
model_class = ControlNetModel
ckpt_path = "https://huggingface.co/lllyasviel/ControlNet-v1-1/blob/main/control_v11p_sd15_canny.pth"
repo_id = "lllyasviel/control_v11p_sd15_canny"
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path)
diff --git a/tests/single_file/test_model_flux_transformer_single_file.py b/tests/single_file/test_model_flux_transformer_single_file.py
index 8290c339b9..0642a71c57 100644
--- a/tests/single_file/test_model_flux_transformer_single_file.py
+++ b/tests/single_file/test_model_flux_transformer_single_file.py
@@ -14,7 +14,6 @@
# limitations under the License.
import gc
-import unittest
from diffusers import (
FluxTransformer2DModel,
@@ -23,52 +22,21 @@ from diffusers import (
from ..testing_utils import (
backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class FluxTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestFluxTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = FluxTransformer2DModel
ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
alternate_keys_ckpt_paths = ["https://huggingface.co/Comfy-Org/flux1-dev/blob/main/flux1-dev-fp8.safetensors"]
repo_id = "black-forest-labs/FLUX.1-dev"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
+ subfolder = "transformer"
def test_device_map_cuda(self):
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_model_motion_adapter_single_file.py b/tests/single_file/test_model_motion_adapter_single_file.py
index 7aaf4b577e..a047c81b47 100644
--- a/tests/single_file/test_model_motion_adapter_single_file.py
+++ b/tests/single_file/test_model_motion_adapter_single_file.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import unittest
from diffusers import (
MotionAdapter,
@@ -27,7 +26,7 @@ from ..testing_utils import (
enable_full_determinism()
-class MotionAdapterSingleFileTests(unittest.TestCase):
+class MotionAdapterSingleFileTests:
model_class = MotionAdapter
def test_single_file_components_version_v1_5(self):
diff --git a/tests/single_file/test_model_sd_cascade_unet_single_file.py b/tests/single_file/test_model_sd_cascade_unet_single_file.py
index a5ec9dba30..7472122710 100644
--- a/tests/single_file/test_model_sd_cascade_unet_single_file.py
+++ b/tests/single_file/test_model_sd_cascade_unet_single_file.py
@@ -14,7 +14,6 @@
# limitations under the License.
import gc
-import unittest
import torch
@@ -37,14 +36,12 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableCascadeUNetSingleFileTest(unittest.TestCase):
- def setUp(self):
- super().setUp()
+class StableCascadeUNetSingleFileTest:
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_model_vae_single_file.py b/tests/single_file/test_model_vae_single_file.py
index 3b9e619f13..9198d9b163 100644
--- a/tests/single_file/test_model_vae_single_file.py
+++ b/tests/single_file/test_model_vae_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
@@ -23,22 +21,18 @@ from diffusers import (
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
load_hf_numpy,
numpy_cosine_similarity_distance,
- require_torch_accelerator,
- slow,
torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@slow
-@require_torch_accelerator
-class AutoencoderKLSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKL
ckpt_path = (
"https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"
@@ -47,16 +41,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
main_input_name = "sample"
base_precision = 1e-2
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
@@ -84,18 +68,6 @@ class AutoencoderKLSingleFileTests(unittest.TestCase):
assert numpy_cosine_similarity_distance(output_slice_1, output_slice_2) < 1e-4
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between pretrained loading and single file loading"
- )
-
def test_single_file_arguments(self):
model_default = self.model_class.from_single_file(self.ckpt_path, config=self.repo_id)
diff --git a/tests/single_file/test_model_wan_autoencoder_single_file.py b/tests/single_file/test_model_wan_autoencoder_single_file.py
index a1f7155c10..0babf30234 100644
--- a/tests/single_file/test_model_wan_autoencoder_single_file.py
+++ b/tests/single_file/test_model_wan_autoencoder_single_file.py
@@ -13,50 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
from diffusers import (
AutoencoderKLWan,
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class AutoencoderKLWanSingleFileTests(unittest.TestCase):
+class TestAutoencoderKLWanSingleFile(SingleFileModelTesterMixin):
model_class = AutoencoderKLWan
ckpt_path = (
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="vae")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "vae"
diff --git a/tests/single_file/test_model_wan_transformer3d_single_file.py b/tests/single_file/test_model_wan_transformer3d_single_file.py
index d7c758d3d9..b769092060 100644
--- a/tests/single_file/test_model_wan_transformer3d_single_file.py
+++ b/tests/single_file/test_model_wan_transformer3d_single_file.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import gc
-import unittest
import torch
@@ -23,72 +21,26 @@ from diffusers import (
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
require_big_accelerator,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class WanTransformer3DModelText2VideoSingleFileTest(unittest.TestCase):
+class TestWanTransformer3DModelText2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors"
repo_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "transformer"
@require_big_accelerator
-@require_torch_accelerator
-class WanTransformer3DModelImage2VideoSingleFileTest(unittest.TestCase):
+class TestWanTransformer3DModelImage2VideoSingleFile(SingleFileModelTesterMixin):
model_class = WanTransformer3DModel
ckpt_path = "https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_i2v_480p_14B_fp8_e4m3fn.safetensors"
repo_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
torch_dtype = torch.float8_e4m3fn
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer", torch_dtype=self.torch_dtype)
- model_single_file = self.model_class.from_single_file(self.ckpt_path, torch_dtype=self.torch_dtype)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
+ subfolder = "transformer"
diff --git a/tests/single_file/test_sana_transformer.py b/tests/single_file/test_sana_transformer.py
index c1543ba171..9e2adb93bf 100644
--- a/tests/single_file/test_sana_transformer.py
+++ b/tests/single_file/test_sana_transformer.py
@@ -1,23 +1,17 @@
-import gc
-import unittest
-
from diffusers import (
SanaTransformer2DModel,
)
from ..testing_utils import (
- backend_empty_cache,
enable_full_determinism,
- require_torch_accelerator,
- torch_device,
)
+from .single_file_testing_utils import SingleFileModelTesterMixin
enable_full_determinism()
-@require_torch_accelerator
-class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
+class TestSanaTransformer2DModelSingleFile(SingleFileModelTesterMixin):
model_class = SanaTransformer2DModel
ckpt_path = (
"https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px/blob/main/checkpoints/Sana_1600M_1024px.pth"
@@ -27,34 +21,4 @@ class SanaTransformer2DModelSingleFileTests(unittest.TestCase):
]
repo_id = "Efficient-Large-Model/Sana_1600M_1024px_diffusers"
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_single_file_components(self):
- model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
- model_single_file = self.model_class.from_single_file(self.ckpt_path)
-
- PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
- for param_name, param_value in model_single_file.config.items():
- if param_name in PARAMS_TO_IGNORE:
- continue
- assert model.config[param_name] == param_value, (
- f"{param_name} differs between single file loading and pretrained loading"
- )
-
- def test_checkpoint_loading(self):
- for ckpt_path in self.alternate_keys_ckpt_paths:
- backend_empty_cache(torch_device)
- model = self.model_class.from_single_file(ckpt_path)
-
- del model
- gc.collect()
- backend_empty_cache(torch_device)
+ subfolder = "transformer"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
index e558eeaf6f..141748b084 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_img2img_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -29,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
index 54224f51a9..8238866cbf 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_inpaint_single_file.py
@@ -1,7 +1,7 @@
import gc
import tempfile
-import unittest
+import pytest
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
@@ -29,19 +29,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "stable-diffusion-v1-5/stable-diffusion-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -115,7 +113,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config(self):
controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_canny", variant="fp16")
pipe = self.pipeline_class.from_pretrained(self.repo_id, controlnet=controlnet)
@@ -125,7 +123,7 @@ class StableDiffusionControlNetInpaintPipelineSingleFileSlowTests(unittest.TestC
super()._compare_component_configs(pipe, pipe_single_file)
- @unittest.skip("runwayml original config repo does not exist")
+ @pytest.mark.skip(reason="runwayml original config repo does not exist")
def test_single_file_components_with_original_config_local_files_only(self):
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16, variant="fp16"
diff --git a/tests/single_file/test_stable_diffusion_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
index e90e648a9d..80ef6c2574 100644
--- a/tests/single_file/test_stable_diffusion_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_controlnet_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -29,7 +28,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionControlNetPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionControlNetPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -39,13 +38,11 @@ class StableDiffusionControlNetPipelineSingleFileSlowTests(unittest.TestCase, SD
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_img2img_single_file.py b/tests/single_file/test_stable_diffusion_img2img_single_file.py
index 387f09471d..e76846c800 100644
--- a/tests/single_file/test_stable_diffusion_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -23,7 +22,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionImg2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -33,13 +32,11 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,19 +63,17 @@ class StableDiffusionImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSin
@slow
@require_torch_accelerator
-class StableDiffusion21Img2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21Img2ImgPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_inpaint_single_file.py b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
index 84636ec0f0..6e5d27cdff 100644
--- a/tests/single_file/test_stable_diffusion_inpaint_single_file.py
+++ b/tests/single_file/test_stable_diffusion_inpaint_single_file.py
@@ -1,6 +1,6 @@
import gc
-import unittest
+import pytest
import torch
from diffusers import (
@@ -23,19 +23,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = "https://huggingface.co/botp/stable-diffusion-v1-5-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
original_config = "https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml"
repo_id = "botp/stable-diffusion-v1-5-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -70,18 +68,18 @@ class StableDiffusionInpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSin
assert pipe.unet.config.in_channels == 4
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config(self):
return
- @unittest.skip("runwayml original config has been removed")
+ @pytest.mark.skip(reason="runwayml original config has been removed")
def test_single_file_components_with_original_config_local_files_only(self):
return
@slow
@require_torch_accelerator
-class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21InpaintPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInpaintPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/512-inpainting-ema.safetensors"
@@ -89,13 +87,11 @@ class StableDiffusion21InpaintPipelineSingleFileSlowTests(unittest.TestCase, SDS
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inpainting-inference.yaml"
repo_id = "stabilityai/stable-diffusion-2-inpainting"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_single_file.py b/tests/single_file/test_stable_diffusion_single_file.py
index 4601b75c3a..377dedbc57 100644
--- a/tests/single_file/test_stable_diffusion_single_file.py
+++ b/tests/single_file/test_stable_diffusion_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -28,7 +27,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = (
"https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"
@@ -38,13 +37,11 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
)
repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -90,19 +87,17 @@ class StableDiffusionPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFile
@slow
-class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusion21PipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/v2-1_768-ema-pruned.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
repo_id = "stabilityai/stable-diffusion-2-1"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -125,7 +120,7 @@ class StableDiffusion21PipelineSingleFileSlowTests(unittest.TestCase, SDSingleFi
@nightly
@slow
@require_torch_accelerator
-class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionInstructPix2PixPipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors"
original_config = (
@@ -134,13 +129,11 @@ class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCas
repo_id = "timbrooks/instruct-pix2pix"
single_file_kwargs = {"extract_ema": True}
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_upscale_single_file.py b/tests/single_file/test_stable_diffusion_upscale_single_file.py
index 39ec7b0194..ba4819fadf 100644
--- a/tests/single_file/test_stable_diffusion_upscale_single_file.py
+++ b/tests/single_file/test_stable_diffusion_upscale_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import pytest
import torch
@@ -25,19 +24,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionUpscalePipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin):
+class TestStableDiffusionUpscalePipelineSingleFileSlow(SDSingleFileTesterMixin):
pipeline_class = StableDiffusionUpscalePipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/blob/main/x4-upscaler-ema.safetensors"
original_config = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
repo_id = "stabilityai/stable-diffusion-x4-upscaler"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
index 3de9ee7364..3d124fa8c2 100644
--- a/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_adapter_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -32,7 +31,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLAdapterPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLAdapterPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -40,13 +39,11 @@ class StableDiffusionXLAdapterPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
index a0a1aba103..6f50370261 100644
--- a/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_controlnet_single_file.py
@@ -1,6 +1,5 @@
import gc
import tempfile
-import unittest
import torch
@@ -28,7 +27,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLControlNetPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLControlNetPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -36,13 +35,11 @@ class StableDiffusionXLControlNetPipelineSingleFileSlowTests(unittest.TestCase,
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
index 810f412f8d..56657f37d9 100644
--- a/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_img2img_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -25,7 +24,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLImg2ImgPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -33,13 +32,11 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@@ -66,7 +63,7 @@ class StableDiffusionXLImg2ImgPipelineSingleFileSlowTests(unittest.TestCase, SDX
@slow
@require_torch_accelerator
-class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests(unittest.TestCase):
+class StableDiffusionXLImg2ImgRefinerPipelineSingleFileSlowTests:
pipeline_class = StableDiffusionXLImg2ImgPipeline
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/blob/main/sd_xl_refiner_1.0.safetensors"
diff --git a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
index 011d59222a..d755b70105 100644
--- a/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
+++ b/tests/single_file/test_stable_diffusion_xl_instruct_pix2pix.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -19,19 +18,17 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLInstructPix2PixPipeline(unittest.TestCase):
+class StableDiffusionXLInstructPix2PixPipeline:
pipeline_class = StableDiffusionXLInstructPix2PixPipeline
ckpt_path = "https://huggingface.co/stabilityai/cosxl/blob/main/cosxl_edit.safetensors"
original_config = None
repo_id = "diffusers/sdxl-instructpix2pix-768"
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/single_file/test_stable_diffusion_xl_single_file.py b/tests/single_file/test_stable_diffusion_xl_single_file.py
index 0ad180de17..4e5319ca25 100644
--- a/tests/single_file/test_stable_diffusion_xl_single_file.py
+++ b/tests/single_file/test_stable_diffusion_xl_single_file.py
@@ -1,5 +1,4 @@
import gc
-import unittest
import torch
@@ -22,7 +21,7 @@ enable_full_determinism()
@slow
@require_torch_accelerator
-class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingleFileTesterMixin):
+class TestStableDiffusionXLPipelineSingleFileSlow(SDXLSingleFileTesterMixin):
pipeline_class = StableDiffusionXLPipeline
ckpt_path = "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"
repo_id = "stabilityai/stable-diffusion-xl-base-1.0"
@@ -30,13 +29,11 @@ class StableDiffusionXLPipelineSingleFileSlowTests(unittest.TestCase, SDXLSingle
"https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
)
- def setUp(self):
- super().setUp()
+ def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
- def tearDown(self):
- super().tearDown()
+ def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 7f849219c1..951ba41280 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -63,6 +63,8 @@ else:
IS_CUDA_SYSTEM = False
IS_XPU_SYSTEM = False
+IS_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" and os.getenv("DIFFUSERS_IS_CI") == "yes"
+
global_rng = random.Random()
logger = get_logger(__name__)