diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 16e1a70b84..384f07506a 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -248,7 +248,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -m "big_gpu_with_torch_cuda" \
+ -m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 87d5177388..fd0c76c2ff 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -188,7 +188,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
- max-parallel: 2
+ max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
index e90cb32c54..9ba4742085 100644
--- a/docs/source/en/api/cache.md
+++ b/docs/source/en/api/cache.md
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
+
+### FirstBlockCacheConfig
+
+[[autodoc]] FirstBlockCacheConfig
+
+[[autodoc]] apply_first_block_cache
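+
+A minimal sketch of enabling first block cache on a model that inherits `CacheMixin` (the Flux checkpoint and the `threshold` value here are assumptions; tune the threshold per model):
+
+```py
+import torch
+from diffusers import FluxPipeline, FirstBlockCacheConfig
+
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
+pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
+image = pipe("a photo of an astronaut riding a horse on mars").images[0]
+```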
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 4e2d144421..40e290e4bd 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
-pipe.enabe_model_cpu_offload()
+pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index df3df92f06..11afbf29d3 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
-#### LoRA files
+#### LoRAs
-[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
+[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or videos in a specific style. If a checkpoint was trained with a Diffusers training script, its LoRA configuration is automatically saved as metadata in the safetensors file. When the file is loaded, the metadata is parsed to correctly configure the LoRA, avoiding missing or incorrect LoRA configurations.
-LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-# base model
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-# download LoRA weights
-!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
-
-# load LoRA weights
-pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
-prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
-negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
-
-image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.manual_seed(0),
-).images[0]
-image
-```
+The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weight files on the Hub.
-

+
+For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`], as long as the weights are saved as a safetensors file.
+
+```py
+import torch
+from diffusers import FluxPipeline
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
+# `save_directory` is required; the directory name here is just an example
+pipeline.save_lora_weights(
+    save_directory="yarn-art-lora",
+    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
+    text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8},
+)
+```
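+
+To inspect the saved metadata locally, you can read it back with the `safetensors` library (a sketch; the path assumes the directory from the example above and the default weight name):
+
+```py
+from safetensors import safe_open
+
+with safe_open("yarn-art-lora/pytorch_lora_weights.safetensors", framework="pt") as f:
+    print(f.metadata())
+```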
+
### ckpt
> [!WARNING]
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 24c71d5c56..18273746c2 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
## Training Kontext
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
-provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
+provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
-Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
+
+> [!NOTE]
+> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To do this, execute the following steps in a new virtual environment:
+> ```
+> git clone https://github.com/huggingface/diffusers
+> cd diffusers
+> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
+> pip install -e .
+> ```
Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
perform as expected.
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+ --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+ --output_dir="kontext-i2i" \
+ --dataset_name="kontext-community/relighting" \
+ --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --optimizer="adamw" \
+ --use_8bit_adam \
+ --cache_latents \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=200 \
+ --max_train_steps=1000 \
+  --rank=16 \
+ --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like [kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting), with condition image, target image, and instruction columns
+* Supply the `--image_column`, `--cond_image_column`, and `--caption_column` values when launching training
+
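+You can quickly sanity-check that a dataset exposes the expected columns before launching training (a sketch using the `datasets` library):
+
+```py
+from datasets import load_dataset
+
+dataset = load_dataset("kontext-community/relighting", split="train")
+# should list your condition image, target image, and instruction columns
+print(dataset.column_names)
+```
+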
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
## Other notes
-Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
+Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 9f97567b06..5bd9b8684d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
-from torchvision.transforms.functional import crop
+from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
free_memory,
parse_buckets_string,
)
-from diffusers.utils import (
- check_min_version,
- convert_unet_state_dict_to_peft,
- is_wandb_available,
-)
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
+ pipeline_args_cp = pipeline_args.copy()
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
- prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
- pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
- )
+ prompt = pipeline_args_cp.pop("prompt")
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
- prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+ **pipeline_args_cp,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ generator=generator,
).images[0]
images.append(image)
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
+ parser.add_argument(
+ "--cond_image_column",
+ type=str,
+ default=None,
+ help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+ )
parser.add_argument(
"--caption_column",
type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
- required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
+ )
parser.add_argument(
"--num_validation_images",
type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
- default="flux-dreambooth-lora",
+ default="flux-kontext-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
+ if args.cond_image_column is not None:
+ raise ValueError("Prior preservation isn't supported with I2I training.")
else:
# logger is not available yet
if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+    if args.cond_image_column is not None:
+        assert args.image_column is not None, "`--image_column` must be specified for I2I fine-tuning."
+        assert args.caption_column is not None, "`--caption_column` must be specified for I2I fine-tuning."
+        assert args.dataset_name is not None, "`--dataset_name` must be specified for I2I fine-tuning."
+        assert not args.train_text_encoder, "Text encoder training isn't supported with I2I fine-tuning."
+        if args.validation_prompt is not None:
+            assert args.validation_image is not None and os.path.exists(args.validation_image), (
+                "`--validation_image` must point to an existing image when validating I2I fine-tuning."
+            )
+
return args
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
repeats=1,
center_crop=False,
buckets=None,
+ args=None,
):
self.center_crop = center_crop
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
+ if args.cond_image_column is not None and args.cond_image_column not in column_names:
+ raise ValueError(
+ f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
- instance_images = dataset["train"][image_column]
+ instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
+ cond_images = None
+ cond_image_column = args.cond_image_column
+ if cond_image_column is not None:
+ cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+ assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.instance_images = []
- for img in instance_images:
+ self.cond_images = []
+ for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
+ if args.dataset_name is not None and cond_images is not None:
+ self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
- for image in self.instance_images:
+ self.cond_pixel_values = []
+ for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
+ dest_image = None
+ if self.cond_images:
+ dest_image = exif_transpose(self.cond_images[i])
+ if not dest_image.mode == "RGB":
+ dest_image = dest_image.convert("RGB")
width, height = image.size
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
- train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
- train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
- train_flip = transforms.RandomHorizontalFlip(p=1.0)
- train_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
+ image, dest_image = self.paired_transform(
+ image,
+ dest_image=dest_image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
)
- image = train_resize(image)
- if args.center_crop:
- image = train_crop(image)
- else:
- y1, x1, h, w = train_crop.get_params(image, self.size)
- image = crop(image, y1, x1, h, w)
- if args.random_flip and random.random() < 0.5:
- image = train_flip(image)
- image = train_transforms(image)
self.pixel_values.append((image, bucket_idx))
+ if dest_image is not None:
+ self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
+ if self.cond_pixel_values:
+ dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+ example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
return example
+ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+ if dest_image is not None:
+ dest_image = resize(dest_image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ if dest_image is not None:
+ dest_image = crop(dest_image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+ if dest_image is not None:
+ dest_image = TF.crop(dest_image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+ if dest_image is not None:
+ dest_image = TF.hflip(dest_image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+ if dest_image is not None:
+ dest_image = normalize(to_tensor(dest_image))
+
+        return image, dest_image
+
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
+ if any("cond_images" in example for example in examples):
+ cond_pixel_values = [example["cond_images"] for example in examples]
+ cond_pixel_values = torch.stack(cond_pixel_values)
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+ batch.update({"cond_pixel_values": cond_pixel_values})
return batch
@@ -1318,6 +1393,7 @@ def main(args):
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
+ "proj_mlp",
]
# now we will add new LoRA weights the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
+ args=args,
)
+ if args.cond_image_column is not None:
+ logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
@@ -1605,19 +1685,41 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
+ cached_text_embeddings = []
+ for batch in tqdm(train_dataloader, desc="Embedding prompts"):
+ batch_prompts = batch["prompts"]
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ batch_prompts, text_encoders, tokenizers
+ )
+ cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
+
+ if args.validation_prompt is None:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
+ del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ free_memory()
+
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
+ has_image_input = args.cond_image_column is not None
if args.cache_latents:
latents_cache = []
+ cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ if has_image_input:
+ batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if args.validation_prompt is None:
+ vae.cpu()
del vae
free_memory()
@@ -1678,7 +1780,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
- tracker_name = "dreambooth-flux-dev-lora"
+ tracker_name = "dreambooth-flux-kontext-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
@@ -1742,6 +1844,7 @@ def main(args):
sigma = sigma.unsqueeze(-1)
return sigma
+ has_guidance = unwrap_model(transformer).config.guidance_embeds
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
- prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
- prompts, text_encoders, tokenizers
- )
+ prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
if args.cache_latents:
if args.vae_encode_mode == "sample":
model_input = latents_cache[step].sample()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].sample()
else:
model_input = latents_cache[step].mode()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].mode()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ if has_image_input:
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
if args.vae_encode_mode == "sample":
model_input = vae.encode(pixel_values).latent_dist.sample()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
else:
model_input = vae.encode(pixel_values).latent_dist.mode()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
+ if has_image_input:
+ cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ cond_model_input = cond_model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
@@ -1814,6 +1928,17 @@ def main(args):
accelerator.device,
weight_dtype,
)
+ if has_image_input:
+ cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
+ cond_model_input.shape[0],
+ cond_model_input.shape[2] // 2,
+ cond_model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
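+                    # Mark conditioning tokens by setting the first id coordinate to 1,
+                    # matching FluxKontextPipeline's convention for reference-image latents.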
+ cond_latents_ids[..., 0] = 1
+ latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
+
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
height=model_input.shape[2],
width=model_input.shape[3],
)
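+                # Keep the packed target sequence length so the conditioning tokens can be
+                # sliced off the model prediction later.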
+ orig_inp_shape = packed_noisy_model_input.shape
+ if has_image_input:
+ packed_cond_input = FluxKontextPipeline._pack_latents(
+ cond_model_input,
+ batch_size=cond_model_input.shape[0],
+ num_channels_latents=cond_model_input.shape[1],
+ height=cond_model_input.shape[2],
+ width=cond_model_input.shape[3],
+ )
+ packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
- # handle guidance
- if unwrap_model(transformer).config.guidance_embeds:
+ # Kontext always has guidance
+ guidance = None
+ if has_guidance:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
- else:
- guidance = None
# Predict the noise residual
model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
img_ids=latent_image_ids,
return_dict=False,
)[0]
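+                # The prediction covers target + conditioning tokens; keep only the
+                # target-image portion for the loss.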
+ if has_image_input:
+ model_pred = model_pred[:, : orig_inp_shape[1]]
model_pred = FluxKontextPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -2030,6 +2167,8 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 4c383c817e..713472b4a5 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -133,9 +133,11 @@ else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
+ "FirstBlockCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
+ "apply_first_block_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -751,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
+ apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 764ceb25b4..365bed3718 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -1,8 +1,23 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
+ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 0000000000..3be77dd4ce
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+ {
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ }
+)
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
new file mode 100644
index 0000000000..960d14e6fa
--- /dev/null
+++ b/src/diffusers/hooks/_helpers.py
@@ -0,0 +1,264 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Type
+
+
+@dataclass
+class AttentionProcessorMetadata:
+ skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class TransformerBlockMetadata:
+ return_hidden_states_index: int = None
+ return_encoder_hidden_states_index: int = None
+
+ _cls: Type = None
+ _cached_parameter_indices: Dict[str, int] = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
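+        # Resolve `identifier` from kwargs if present; otherwise map it to a positional
+        # index via the signature of the registered block's forward (cached on first use).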
+ kwargs = kwargs or {}
+ if identifier in kwargs:
+ return kwargs[identifier]
+ if self._cached_parameter_indices is not None:
+ return args[self._cached_parameter_indices[identifier]]
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+ if identifier not in self._cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+ index = self._cached_parameter_indices[identifier]
+        if index >= len(args):
+            raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
+ return args[index]
+
+
+class AttentionProcessorRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+ cls._register()
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_attention_processors_metadata()
+
+
+class TransformerBlockRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+ cls._register()
+ metadata._cls = model_class
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> TransformerBlockMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_transformer_blocks_metadata()
+
+
+def _register_attention_processors_metadata():
+ from ..models.attention_processor import AttnProcessor2_0
+ from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+
+ # AttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=AttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
+ ),
+ )
+
+ # CogView4AttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=CogView4AttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+ ),
+ )
+
+
+def _register_transformer_blocks_metadata():
+ from ..models.attention import BasicTransformerBlock
+ from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+ from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+ from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+ from ..models.transformers.transformer_hunyuan_video import (
+ HunyuanVideoSingleTransformerBlock,
+ HunyuanVideoTokenReplaceSingleTransformerBlock,
+ HunyuanVideoTokenReplaceTransformerBlock,
+ HunyuanVideoTransformerBlock,
+ )
+ from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+ from ..models.transformers.transformer_mochi import MochiTransformerBlock
+ from ..models.transformers.transformer_wan import WanTransformerBlock
+
+ # BasicTransformerBlock
+ TransformerBlockRegistry.register(
+ model_class=BasicTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # CogVideoX
+ TransformerBlockRegistry.register(
+ model_class=CogVideoXBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # CogView4
+ TransformerBlockRegistry.register(
+ model_class=CogView4TransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Flux
+ TransformerBlockRegistry.register(
+ model_class=FluxTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=FluxSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+
+ # HunyuanVideo
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # LTXVideo
+ TransformerBlockRegistry.register(
+ model_class=LTXVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # Mochi
+ TransformerBlockRegistry.register(
+ model_class=MochiTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Wan
+ TransformerBlockRegistry.register(
+ model_class=WanTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+
+# fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ return hidden_states
+
+
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ if encoder_hidden_states is None and len(args) > 1:
+ encoder_hidden_states = args[1]
+ return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+# fmt: on
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
new file mode 100644
index 0000000000..40ae8c5a26
--- /dev/null
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -0,0 +1,227 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
+from ._helpers import TransformerBlockRegistry
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
+_FBC_BLOCK_HOOK = "fbc_block_hook"
+
+
+@dataclass
+class FirstBlockCacheConfig:
+ r"""
+ Configuration for [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
+
+ Args:
+ threshold (`float`, defaults to `0.05`):
+ The threshold to determine whether or not a forward pass through all layers of the model is required. A
+ higher threshold usually results in a forward pass through a lower number of layers and faster inference,
+ but might lead to poorer generation quality. A lower threshold may not result in significant generation
+ speedup. The threshold is compared against the absmean difference of the residuals between the current and
+ cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
+ is skipped.
+ """
+
+ threshold: float = 0.05
+
+
+class FBCSharedBlockState(BaseState):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.head_block_residual: torch.Tensor = None
+ self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.should_compute: bool = True
+
+ def reset(self):
+ self.tail_block_residuals = None
+ self.should_compute = True
+
+
+class FBCHeadBlockHook(ModelHook):
+ _is_stateful = True
+
+    def __init__(self, state_manager: StateManager, threshold: float):
+        super().__init__()
+        self.state_manager = state_manager
+        self.threshold = threshold
+        self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ is_output_tuple = isinstance(output, tuple)
+
+ if is_output_tuple:
+ hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
+ else:
+ hidden_states_residual = output - original_hidden_states
+
+ shared_state: FBCSharedBlockState = self.state_manager.get_state()
+ hidden_states = encoder_hidden_states = None
+ should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
+ shared_state.should_compute = should_compute
+
+ if not should_compute:
+ # Apply caching
+ if is_output_tuple:
+ hidden_states = (
+ shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+ )
+ else:
+ hidden_states = shared_state.tail_block_residuals[0] + output
+
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ assert is_output_tuple
+ encoder_hidden_states = (
+ shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
+ )
+
+ if is_output_tuple:
+ return_output = [None] * len(output)
+ return_output[self._metadata.return_hidden_states_index] = hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
+ return_output = tuple(return_output)
+ else:
+ return_output = hidden_states
+ output = return_output
+ else:
+ if is_output_tuple:
+ head_block_output = [None] * len(output)
+ head_block_output[0] = output[self._metadata.return_hidden_states_index]
+ head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
+ else:
+ head_block_output = output
+ shared_state.head_block_output = head_block_output
+ shared_state.head_block_residual = hidden_states_residual
+
+ return output
+
+ def reset_state(self, module):
+ self.state_manager.reset()
+ return module
+
+ @torch.compiler.disable
+ def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
+ shared_state = self.state_manager.get_state()
+ if shared_state.head_block_residual is None:
+ return True
+ prev_hidden_states_residual = shared_state.head_block_residual
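+        # Relative change of the first block's residual vs. the previous step:
+        # mean(|r_t - r_prev|) / mean(|r_prev|). Recompute all blocks only if it exceeds the threshold.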
+ absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
+ prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
+ diff = (absmean / prev_hidden_states_absmean).item()
+ return diff > self.threshold
+
+
+class FBCBlockHook(ModelHook):
+ def __init__(self, state_manager: StateManager, is_tail: bool = False):
+ super().__init__()
+ self.state_manager = state_manager
+ self.is_tail = is_tail
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ original_encoder_hidden_states = None
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+
+ shared_state = self.state_manager.get_state()
+
+ if shared_state.should_compute:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ if self.is_tail:
+ hidden_states_residual = encoder_hidden_states_residual = None
+ if isinstance(output, tuple):
+ hidden_states_residual = (
+ output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
+ )
+ encoder_hidden_states_residual = (
+ output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
+ )
+ else:
+ hidden_states_residual = output - shared_state.head_block_output
+ shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+ return output
+
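+        # Skip path: the head block already applied the cached tail residuals, so the
+        # remaining blocks act as identity and simply pass their inputs through.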
+ if original_encoder_hidden_states is None:
+ return_output = original_hidden_states
+ else:
+ return_output = [None, None]
+ return_output[self._metadata.return_hidden_states_index] = original_hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+ return_output = tuple(return_output)
+ return return_output
+
+
+def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+ state_manager = StateManager(FBCSharedBlockState, (), {})
+ remaining_blocks = []
+
+ for name, submodule in module.named_children():
+ if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
+ continue
+ for index, block in enumerate(submodule):
+ remaining_blocks.append((f"{name}.{index}", block))
+
+ head_block_name, head_block = remaining_blocks.pop(0)
+ tail_block_name, tail_block = remaining_blocks.pop(-1)
+
+ logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
+ _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
+
+ for name, block in remaining_blocks:
+ logger.debug(f"Applying FBCBlockHook to '{name}'")
+ _apply_fbc_block_hook(block, state_manager)
+
+ logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
+ _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
+
+
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCHeadBlockHook(state_manager, threshold)
+ registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
+
+
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCBlockHook(state_manager, is_tail)
+ registry.register_hook(hook, _FBC_BLOCK_HOOK)
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 96231aadc3..6e097e5882 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
+class BaseState:
+ def reset(self, *args, **kwargs) -> None:
+ raise NotImplementedError(
+ "BaseState::reset is not implemented. Please implement this method in the derived class."
+ )
+
+
+class StateManager:
+ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+ self._state_cls = state_cls
+ self._init_args = init_args if init_args is not None else ()
+ self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+ self._state_cache = {}
+ self._current_context = None
+
+ def get_state(self):
+ if self._current_context is None:
+ raise ValueError("No context is set. Please set a context before retrieving the state.")
+        if self._current_context not in self._state_cache:
+ self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+ return self._state_cache[self._current_context]
+
+ def set_context(self, name: str) -> None:
+ self._current_context = name
+
+ def reset(self, *args, **kwargs) -> None:
+ for name, state in list(self._state_cache.items()):
+ state.reset(*args, **kwargs)
+ self._state_cache.pop(name)
+ self._current_context = None
+
+
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
+    def _set_context(self, module: torch.nn.Module, name: str) -> torch.nn.Module:
+ # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+ for attr_name in dir(self):
+ attr = getattr(self, attr_name)
+ if isinstance(attr, StateManager):
+ attr.set_context(name)
+ return module
+
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
- for module_name, module in self._module_ref.named_modules():
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
+ module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
+ def _set_context(self, name: Optional[str] = None) -> None:
+ for hook_name in reversed(self._hook_order):
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook._set_context(self._module_ref, name)
+
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
+ if module_name == "":
+ continue
+ module = unwrap_module(module)
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook._set_context(name)
+
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 3fd1ca6e9d..605c0d588c 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from contextlib import contextmanager
+
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
+ - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
+ apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
- if isinstance(config, PyramidAttentionBroadcastConfig):
- apply_pyramid_attention_broadcast(self, config)
- elif isinstance(config, FasterCacheConfig):
+ if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
+ elif isinstance(config, FirstBlockCacheConfig):
+ apply_first_block_cache(self, config)
+ elif isinstance(config, PyramidAttentionBroadcastConfig):
+ apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
- from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+ from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+ from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
- if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
- registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
- elif isinstance(self._cache_config, FasterCacheConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, FirstBlockCacheConfig):
+ registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+ registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+ @contextmanager
+ def cache_context(self, name: str):
+        r"""Context manager that runs the wrapped code under a named cache context, keeping a separate cache state per name."""
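+        # Usage sketch (assumes `self` is a denoiser that inherits `CacheMixin` with
+        # caching enabled): run the conditional and unconditional branches under
+        # different context names so their cached states do not interfere:
+        #
+        #     with transformer.cache_context("cond"):
+        #         noise_pred = transformer(...)
+        #     with transformer.cache_context("uncond"):
+        #         noise_pred_uncond = transformer(...)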
+ from ..hooks import HookRegistry
+
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry._set_context(name)
+
+        try:
+            yield
+        finally:
+            # Always reset the context, even if the wrapped code raises.
+            registry._set_context(None)
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index d8e99ee45e..063ff5bd8e 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+ single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index e4144d0c8e..dc45befb98 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
+@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 3af1de2ad0..3a7202d0f4 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
- ) -> torch.Tensor:
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
- return hidden_states
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -530,12 +536,8 @@
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
- + controlnet_single_block_samples[index_block // interval_control]
- )
-
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
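+ # hidden_states now holds only image tokens, so the controlnet residual is added directly.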
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 5fb71b69f7..bdb9201e62 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -249,6 +250,7 @@ class WanRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
+@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index f08a3c35c2..3c5994172c 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -718,14 +718,16 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
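+ # Cond and uncond inputs are batched into one forward pass, so a single shared cache context covers both.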
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index fe3e8ae388..cf6ccebc47 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a982f4b275..d1f02ca9c9 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- ofs=ofs_emb,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ ofs=ofs_emb,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 7c50bdcb7d..230c8ca296 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 880253459e..d8374b694f 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
- noise_pred_cond = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- original_size=original_size,
- target_size=target_size,
- crop_coords=crops_coords_top_left,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- # perform guidance
- if self.do_classifier_free_guidance:
- noise_pred_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,20 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
+ # perform guidance
+ if self.do_classifier_free_guidance:
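+ # A distinct context keeps the uncond branch's cached state separate from the cond branch.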
+ with self.transformer.cache_context("uncond"):
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 4c83ae7405..073d94750a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -912,32 +912,35 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- if negative_image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index b617e4f8b2..2cbb4af2b4 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=prompt_attention_mask,
- pooled_projections=pooled_prompt_embeds,
- guidance=guidance,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
- encoder_attention_mask=negative_prompt_attention_mask,
- pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 3b58b4a45a..77ba751700 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index fa9ee4fc7b..217478f418 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- video_coords=video_coords,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 99412b6962..8793d81377 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 7712b41524..3c0f908296 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 6df66118b0..d14dac91f1 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 2981f3a420..6d25047a0f 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class FirstBlockCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
+def apply_first_block_cache(*args, **kwargs):
+ requires_backends(apply_first_block_cache, ["torch"])
+
+
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index e5da39c1d8..ebb3d70553 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -421,6 +421,11 @@ def require_big_accelerator(test_case):
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
Flux, SD3, Cog, etc.
"""
+ import pytest
+
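+ # Also attach the pytest marker so these tests can be selected with -m "big_accelerator".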
+ test_case = pytest.mark.big_accelerator(test_case)
+
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index ffc1119727..61a5d95b69 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+def unwrap_module(module):
+ """Unwraps a module if it was compiled with torch.compile()"""
+ return module._orig_mod if is_compiled_module(module) else module
+
+
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
diff --git a/tests/conftest.py b/tests/conftest.py
index 7e9c4e8f39..3237fb9c7b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,11 @@ sys.path.insert(1, git_repo_path)
warnings.simplefilter(action="ignore", category=FutureWarning)
+def pytest_configure(config):
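+ # Register the custom marker so pytest recognizes it and does not emit unknown-marker warnings.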
+ config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+
+
def pytest_addoption(parser):
from diffusers.utils.testing_utils import pytest_addoption_shared
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 336ac2246f..95f1e137e9 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -20,7 +20,6 @@ import tempfile
import unittest
import numpy as np
-import pytest
import safetensors.torch
import torch
from parameterized import parameterized
@@ -813,7 +812,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -960,7 +958,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 19e31f320d..4cbd6523e7 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -198,7 +197,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8a8f2a676d..8928ccbac2 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -139,7 +138,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index c725589781..a6cb558513 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -45,7 +46,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5ee94b09ba..5b336edc7a 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -211,7 +210,6 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 712c26b0a2..1f1f800bcf 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -18,7 +18,6 @@ import unittest
from typing import Optional
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -221,7 +220,6 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index cbdf617d71..0df0e028ff 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -25,6 +24,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +34,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
- unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -224,7 +225,6 @@ class FluxPipelineFastTests(
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,6 @@ class FluxPipelineSlowTests(unittest.TestCase):
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index b8f36dfd3c..b73050a64d 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
@@ -19,7 +18,6 @@ from diffusers.utils.testing_utils import (
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index ecc5eba964..10101af75c 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index 1d1eb08234..bf0c7fde59 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
- num_layers=1,
+ num_layers=num_layers,
caption_channels=32,
)
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index 5b00261b06..f1684cce72 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -17,7 +17,6 @@ import inspect
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -33,13 +32,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+ PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -268,7 +269,6 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
@nightly
@require_torch_accelerator
@require_big_accelerator
-@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 577ac4ebdd..2179ec8e22 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -233,7 +232,6 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index f5b5e63a81..7f913cb63d 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,7 +3,6 @@ import random
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -168,7 +167,6 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index f87778b260..13c25ccaa4 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2648,7 +2649,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
- output = run_forward(pipe).flatten().flatten()
+ output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2755,6 +2756,56 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+ # The threshold is intentionally set higher than typical real-world values because these tests run on
+ # small, randomly initialized models whose outputs lack the step-to-step smoothness that caching relies on.
+ first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
+ def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
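+ # Use two blocks so there is at least one block to skip when the first block's output changes little.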
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ return pipe(**inputs)[0]
+
+ # Run inference without FirstBlockCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache enabled
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.first_block_cache_config)
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+ "FirstBlockCache outputs should not differ much."
+ )
+ assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
+
+
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 06116cac3a..98005cfbc8 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -872,6 +872,7 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
+@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 2ea4cdfde8..f3bbc34e8b 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -837,6 +837,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
+@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):