mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into gguf-compile-offload

This commit is contained in:
Sayak Paul
2025-07-09 11:00:49 +05:30
committed by GitHub
50 changed files with 1164 additions and 268 deletions

View File

@@ -248,7 +248,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-m "big_gpu_with_torch_cuda" \
-m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/

View File

@@ -188,7 +188,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
max-parallel: 2
max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:

View File

@@ -28,3 +28,9 @@ Cache methods speed up diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
[[autodoc]] apply_first_block_cache
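A minimal usage sketch (assuming a pipeline whose denoiser inherits `CacheMixin`, e.g. `FluxTransformer2DModel`):

```py
import torch
from diffusers import FluxPipeline, FirstBlockCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Install the first-block-cache hooks on the denoiser. A higher threshold skips
# the remaining blocks more often (faster inference, possibly lower quality).
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))

image = pipe("A cat holding a sign that says hello world").images[0]
```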

View File

@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
pipe.enabe_model_cpu_offload()
pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."

View File

@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
</hfoption>
</hfoptions>
#### LoRA files
#### LoRAs
[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoids missing or incorrect LoRA configurations.
LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
```py
from diffusers import StableDiffusionXLPipeline
import torch
# base model
pipeline = StableDiffusionXLPipeline.from_pretrained(
"Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
# download LoRA weights
!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
# load LoRA weights
pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
image = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
generator=torch.manual_seed(0),
).images[0]
image
```
The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
<div class="flex justify-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/blueprint-lora.png"/>
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/safetensors_lora.png"/>
</div>
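To inspect the metadata programmatically instead, here is a minimal sketch using the `safetensors` library (the file name below is only an example):

```py
from safetensors import safe_open

# Only the header is read; tensor data is not loaded into memory.
with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())  # LoRA configuration stored by the training script, if any
```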
For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
```py
import torch
from diffusers import FluxPipeline
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
pipeline.save_lora_weights(
transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
)
```
### ckpt
> [!WARNING]

View File

@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
## Training Kontext
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply to this script, too.
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
**important**
> [!NOTE]
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
> To do this, execute the following steps in a new virtual environment:
> ```
> git clone https://github.com/huggingface/diffusers
> cd diffusers
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
> pip install -e .
> ```
Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
perform as expected.
Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
* Condition image
* Target image
* Instruction
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
```bash
accelerate launch train_dreambooth_lora_flux_kontext.py \
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
--output_dir="kontext-i2i" \
--dataset_name="kontext-community/relighting" \
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
--mixed_precision="bf16" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--optimizer="adamw" \
--use_8bit_adam \
--cache_latents \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=200 \
--max_train_steps=1000 \
--rank=16\
--seed="0"
```
More generally, when performing I2I fine-tuning, we expect you to:
* Have a dataset structured like `kontext-community/relighting`, i.e., providing a condition image, a target image, and an instruction per sample (see the sketch after this list)
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
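As a rough illustration, a triplet dataset structured like `kontext-community/relighting` can be inspected as follows (column names are taken from the training command above; adapt them to your own dataset):

```py
from datasets import load_dataset

ds = load_dataset("kontext-community/relighting", split="train")
sample = ds[0]

cond_image = sample["file_name"]     # condition image  -> --cond_image_column
target_image = sample["output"]      # target image     -> --image_column
instruction = sample["instruction"]  # edit instruction -> --caption_column
print(instruction)
```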
### Misc notes
* By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses the `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
Since Flux Kontext fine-tuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗
## Other notes
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️

View File

@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
from torchvision.transforms.functional import crop
from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
free_memory,
parse_buckets_string,
)
from diffusers.utils import (
check_min_version,
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
pipeline_args_cp = pipeline_args.copy()
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
prompt = pipeline_args_cp.pop("prompt")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
**pipeline_args_cp,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator,
).images[0]
images.append(image)
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
parser.add_argument(
"--cond_image_column",
type=str,
default=None,
help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
)
parser.add_argument(
"--caption_column",
type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
default="flux-dreambooth-lora",
default="flux-kontext-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
if args.cond_image_column is not None:
raise ValueError("Prior preservation isn't supported with I2I training.")
else:
# logger is not available yet
if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
if args.cond_image_column is not None:
assert args.image_column is not None
assert args.caption_column is not None
assert args.dataset_name is not None
assert not args.train_text_encoder
if args.validation_prompt is not None:
assert args.validation_image is None or os.path.exists(args.validation_image)
return args
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
repeats=1,
center_crop=False,
buckets=None,
args=None,
):
self.center_crop = center_crop
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.cond_image_column is not None and args.cond_image_column not in column_names:
raise ValueError(
f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
instance_images = dataset["train"][image_column]
instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
cond_images = None
cond_image_column = args.cond_image_column
if cond_image_column is not None:
cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.instance_images = []
for img in instance_images:
self.cond_images = []
for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
if args.dataset_name is not None and cond_images is not None:
self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
for image in self.instance_images:
self.cond_pixel_values = []
for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
dest_image = None
if self.cond_images:
dest_image = exif_transpose(self.cond_images[i])
if not dest_image.mode == "RGB":
dest_image = dest_image.convert("RGB")
width, height = image.size
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
image, dest_image = self.paired_transform(
image,
dest_image=dest_image,
size=self.size,
center_crop=args.center_crop,
random_flip=args.random_flip,
)
image = train_resize(image)
if args.center_crop:
image = train_crop(image)
else:
y1, x1, h, w = train_crop.get_params(image, self.size)
image = crop(image, y1, x1, h, w)
if args.random_flip and random.random() < 0.5:
image = train_flip(image)
image = train_transforms(image)
self.pixel_values.append((image, bucket_idx))
if dest_image is not None:
self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
if self.cond_pixel_values:
dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
return example
def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
# 1. Resize (deterministic)
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
image = resize(image)
if dest_image is not None:
dest_image = resize(dest_image)
# 2. Crop: either center or SAME random crop
if center_crop:
crop = transforms.CenterCrop(size)
image = crop(image)
if dest_image is not None:
dest_image = crop(dest_image)
else:
# get_params returns (i, j, h, w)
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
image = TF.crop(image, i, j, h, w)
if dest_image is not None:
dest_image = TF.crop(dest_image, i, j, h, w)
# 3. Random horizontal flip with the SAME coin flip
if random_flip:
do_flip = random.random() < 0.5
if do_flip:
image = TF.hflip(image)
if dest_image is not None:
dest_image = TF.hflip(dest_image)
# 4. ToTensor + Normalize (deterministic)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
image = normalize(to_tensor(image))
if dest_image is not None:
dest_image = normalize(to_tensor(dest_image))
return (image, dest_image) if dest_image is not None else (image, None)
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
if any("cond_images" in example for example in examples):
cond_pixel_values = [example["cond_images"] for example in examples]
cond_pixel_values = torch.stack(cond_pixel_values)
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
batch.update({"cond_pixel_values": cond_pixel_values})
return batch
@@ -1318,6 +1393,7 @@ def main(args):
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
"proj_mlp",
]
# now we will add new LoRA weights the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
args=args,
)
if args.cond_image_column is not None:
logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
@@ -1605,19 +1685,41 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
cached_text_embeddings = []
for batch in tqdm(train_dataloader, desc="Embedding prompts"):
batch_prompts = batch["prompts"]
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
batch_prompts, text_encoders, tokenizers
)
cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
if args.validation_prompt is None:
text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
has_image_input = args.cond_image_column is not None
if args.cache_latents:
latents_cache = []
cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
if has_image_input:
batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if args.validation_prompt is None:
vae.cpu()
del vae
free_memory()
@@ -1678,7 +1780,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_name = "dreambooth-flux-dev-lora"
tracker_name = "dreambooth-flux-kontext-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
@@ -1742,6 +1844,7 @@ def main(args):
sigma = sigma.unsqueeze(-1)
return sigma
has_guidance = unwrap_model(transformer).config.guidance_embeds
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
prompts, text_encoders, tokenizers
)
prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
if args.cache_latents:
if args.vae_encode_mode == "sample":
model_input = latents_cache[step].sample()
if has_image_input:
cond_model_input = cond_latents_cache[step].sample()
else:
model_input = latents_cache[step].mode()
if has_image_input:
cond_model_input = cond_latents_cache[step].mode()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
if has_image_input:
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
if args.vae_encode_mode == "sample":
model_input = vae.encode(pixel_values).latent_dist.sample()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
else:
model_input = vae.encode(pixel_values).latent_dist.mode()
if has_image_input:
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
if has_image_input:
cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
cond_model_input = cond_model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
@@ -1814,6 +1928,17 @@ def main(args):
accelerator.device,
weight_dtype,
)
if has_image_input:
cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
cond_model_input.shape[0],
cond_model_input.shape[2] // 2,
cond_model_input.shape[3] // 2,
accelerator.device,
weight_dtype,
)
cond_latents_ids[..., 0] = 1
latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
height=model_input.shape[2],
width=model_input.shape[3],
)
orig_inp_shape = packed_noisy_model_input.shape
if has_image_input:
packed_cond_input = FluxKontextPipeline._pack_latents(
cond_model_input,
batch_size=cond_model_input.shape[0],
num_channels_latents=cond_model_input.shape[1],
height=cond_model_input.shape[2],
width=cond_model_input.shape[3],
)
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
# handle guidance
if unwrap_model(transformer).config.guidance_embeds:
# Kontext always has guidance
guidance = None
if has_guidance:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
else:
guidance = None
# Predict the noise residual
model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
img_ids=latent_image_ids,
return_dict=False,
)[0]
if has_image_input:
model_pred = model_pred[:, : orig_inp_shape[1]]
model_pred = FluxKontextPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -2030,6 +2167,8 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
if has_image_input and args.validation_image:
pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,

View File

@@ -133,9 +133,11 @@ else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"FirstBlockCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
"apply_first_block_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -751,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from .models import (

View File

@@ -1,8 +1,23 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook

View File

@@ -0,0 +1,30 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)

View File

@@ -0,0 +1,264 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
_cls: Type = None
_cached_parameter_indices: Dict[str, int] = None
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
kwargs = kwargs or {}
if identifier in kwargs:
return kwargs[identifier]
if self._cached_parameter_indices is not None:
return args[self._cached_parameter_indices[identifier]]
if self._cls is None:
raise ValueError("Model class is not set for metadata.")
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
parameters = parameters[1:] # skip `self`
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
if identifier not in self._cached_parameter_indices:
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
index = self._cached_parameter_indices[identifier]
if index >= len(args):
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
return args[index]
class AttentionProcessorRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._register()
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_attention_processors_metadata()
class TransformerBlockRegistry:
_registry = {}
# TODO(aryan): this is only required for the time being because we need to do the registrations
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
# import errors because of the models imported in this file.
_is_registered = False
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._register()
metadata._cls = model_class
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
cls._register()
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
@classmethod
def _register(cls):
if cls._is_registered:
return
cls._is_registered = True
_register_transformer_blocks_metadata()
def _register_attention_processors_metadata():
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
),
)
# CogView4AttnProcessor
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_transformer_blocks_metadata():
from ..models.attention import BasicTransformerBlock
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
# fmt: on
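A minimal sketch of how the registry is consumed (internal API; the module path below assumes this file lives at `diffusers/hooks/_helpers.py`):

```py
from diffusers.hooks._helpers import TransformerBlockRegistry
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

# The first access lazily runs _register_transformer_blocks_metadata().
metadata = TransformerBlockRegistry.get(FluxTransformerBlock)
print(metadata.return_hidden_states_index)          # 1
print(metadata.return_encoder_hidden_states_index)  # 0
```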

View File

@@ -0,0 +1,227 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
but might lead to poorer generation quality. A lower threshold may not result in significant generation
speedup. The threshold is compared against the absmean difference of the residuals between the current and
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
is skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, threshold: float):
self.state_manager = state_manager
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
else:
hidden_states_residual = output - original_hidden_states
shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hidden_states = (
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
)
else:
hidden_states = shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
encoder_hidden_states = (
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return_output = tuple(return_output)
else:
return_output = hidden_states
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
shared_state.head_block_output = head_block_output
shared_state.head_block_residual = hidden_states_residual
return output
def reset_state(self, module):
self.state_manager.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
original_encoder_hidden_states = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
shared_state = self.state_manager.get_state()
if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hidden_states_residual = encoder_hidden_states_residual = None
if isinstance(output, tuple):
hidden_states_residual = (
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
)
encoder_hidden_states_residual = (
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
)
else:
hidden_states_residual = output - shared_state.head_block_output
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
return output
if original_encoder_hidden_states is None:
return_output = original_hidden_states
else:
return_output = [None, None]
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return_output = tuple(return_output)
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Applying FBCBlockHook to '{name}'")
_apply_fbc_block_hook(block, state_manager)
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state_manager, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state_manager, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)
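A minimal sketch of applying the hooks directly to a denoiser module (assuming a Flux transformer as the target; `enable_cache` on the model is the higher-level entry point):

```py
import torch
from diffusers import FluxTransformer2DModel
from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Registers FBCHeadBlockHook on the first transformer block and FBCBlockHook on the rest.
apply_first_block_cache(transformer, FirstBlockCacheConfig(threshold=0.05))
```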

View File

@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseState:
def reset(self, *args, **kwargs) -> None:
raise NotImplementedError(
"BaseState::reset is not implemented. Please implement this method in the derived class."
)
class StateManager:
def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
self._state_cls = state_cls
self._init_args = init_args if init_args is not None else ()
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
self._state_cache = {}
self._current_context = None
def get_state(self):
if self._current_context is None:
raise ValueError("No context is set. Please set a context before retrieving the state.")
if self._current_context not in self._state_cache.keys():
self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
return self._state_cache[self._current_context]
def set_context(self, name: str) -> None:
self._current_context = name
def reset(self, *args, **kwargs) -> None:
for name, state in list(self._state_cache.items()):
state.reset(*args, **kwargs)
self._state_cache.pop(name)
self._current_context = None
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
def _set_context(self, module: torch.nn.Module, name: str) -> None:
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, StateManager):
attr.set_context(name)
return module
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
for module_name, module in self._module_ref.named_modules():
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
def _set_context(self, name: Optional[str] = None) -> None:
for hook_name in reversed(self._hook_order):
hook = self.hooks[hook_name]
if hook._is_stateful:
hook._set_context(self._module_ref, name)
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(name)
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
if isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
elif isinstance(config, FasterCacheConfig):
if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
elif isinstance(config, FirstBlockCacheConfig):
apply_first_block_cache(self, config)
elif isinstance(config, PyramidAttentionBroadcastConfig):
apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
elif isinstance(self._cache_config, FasterCacheConfig):
registry = HookRegistry.check_if_exists_or_initialize(self)
registry = HookRegistry.check_if_exists_or_initialize(self)
if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, FirstBlockCacheConfig):
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@contextmanager
def cache_context(self, name: str):
r"""Context manager that provides additional methods for cache management."""
from ..hooks import HookRegistry
registry = HookRegistry.check_if_exists_or_initialize(self)
registry._set_context(name)
yield
registry._set_context(None)

View File

@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()

View File

@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
self,

View File

@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
return hidden_states
encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
hidden_states = block(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -530,12 +536,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ controlnet_single_block_samples[index_block // interval_control]
)
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)

View File

@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -249,6 +250,7 @@ class WanRotaryPosEmbed(nn.Module):
return freqs_cos, freqs_sin
@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,

View File

@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
ofs=ofs_emb,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
image_rotary_emb=image_rotary_emb,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
# perform guidance

View File

@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond

View File

@@ -912,32 +912,35 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
pooled_projections=pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=negative_text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
neg_noise_pred = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if do_true_cfg:
with self.transformer.cache_context("uncond"):
neg_noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_attention_mask=negative_prompt_attention_mask,
pooled_projections=negative_pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:

View File

@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
video_coords=video_coords,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
video_coords=video_coords,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)

View File

@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
rope_interpolation_scale=rope_interpolation_scale,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:

View File

@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
with self.transformer.cache_context("cond_uncond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)

View File

@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class FirstBlockCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
def apply_first_block_cache(*args, **kwargs):
requires_backends(apply_first_block_cache, ["torch"])
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
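
This hunk registers `FirstBlockCacheConfig` and `apply_first_block_cache` in the dummy-object module so that importing them without PyTorch installed fails with an informative error rather than an `AttributeError`. A minimal sketch of the pattern, assuming simplified versions of `DummyObject` and `requires_backends`; the real helpers in `diffusers.utils` are more featureful:

```py
import importlib.util


def requires_backends(obj, backends):
    # Raise a helpful ImportError when any required backend is missing.
    missing = [name for name in backends if importlib.util.find_spec(name) is None]
    if missing:
        obj_name = getattr(obj, "__name__", type(obj).__name__)
        raise ImportError(f"{obj_name} requires the following backends: {missing}")


class DummyObject(type):
    # Metaclass for placeholder classes that stand in for real ones when a
    # backend is unavailable; class-level attribute access re-raises the
    # same informative error.
    def __getattr__(cls, key):
        requires_backends(cls, cls._backends)


class FirstBlockCacheConfig(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```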

View File

@@ -421,6 +421,10 @@ def require_big_accelerator(test_case):
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
Flux, SD3, Cog, etc.
"""
import pytest
test_case = pytest.mark.big_accelerator(test_case)
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
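
With this change, `require_big_accelerator` applies the `big_accelerator` pytest marker itself (the marker is registered in `conftest.py` further down), which is why the explicit `@pytest.mark.big_accelerator` lines disappear from the test modules below and why these tests can be selected with `pytest -m big_accelerator`. A minimal sketch of a decorator that combines a marker with a conditional skip; the real helper in `diffusers.utils.testing_utils` may perform additional checks:

```py
import importlib.util
import unittest

import pytest


def require_big_accelerator(test_case):
    # Tag the test so `pytest -m big_accelerator` can select it, then skip
    # it entirely when PyTorch is not installed.
    test_case = pytest.mark.big_accelerator(test_case)
    if importlib.util.find_spec("torch") is None:
        return unittest.skip("test requires PyTorch")(test_case)
    return test_case
```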

View File

@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).

View File

@@ -30,6 +30,10 @@ sys.path.insert(1, git_repo_path)
warnings.simplefilter(action="ignore", category=FutureWarning)
def pytest_configure(config):
config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
def pytest_addoption(parser):
from diffusers.utils.testing_utils import pytest_addoption_shared

View File

@@ -20,7 +20,6 @@ import tempfile
import unittest
import numpy as np
import pytest
import safetensors.torch
import torch
from parameterized import parameterized
@@ -813,7 +812,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -960,7 +958,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0

View File

@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -198,7 +197,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.

View File

@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -139,7 +138,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -45,7 +46,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}

View File

@@ -17,7 +17,6 @@ import gc
import unittest
import numpy as np
import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -211,7 +210,6 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline

View File

@@ -18,7 +18,6 @@ import unittest
from typing import Optional
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -221,7 +220,6 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
@slow
@require_big_accelerator
@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline

View File

@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -25,6 +24,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +34,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -224,7 +225,6 @@ class FluxPipelineFastTests(
@nightly
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,6 @@ class FluxPipelineSlowTests(unittest.TestCase):
@slow
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"

View File

@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
@@ -19,7 +18,6 @@ from diffusers.utils.testing_utils import (
@slow
@require_big_accelerator
@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"

View File

@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
FirstBlockCacheTesterMixin,
unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])

View File

@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
num_layers=1,
num_layers=num_layers,
caption_channels=32,
)

View File

@@ -17,7 +17,6 @@ import inspect
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -33,13 +32,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
class MochiPipelineFastTests(
PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -268,7 +269,6 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
@nightly
@require_torch_accelerator
@require_big_accelerator
@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."

View File

@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -233,7 +232,6 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_big_accelerator
@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -3,7 +3,6 @@ import random
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -168,7 +167,6 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
@slow
@require_big_accelerator
@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

View File

@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2648,7 +2649,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
output = run_forward(pipe).flatten().flatten()
output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2755,6 +2756,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
# of the box once there is better cache support/implementation
class FirstBlockCacheTesterMixin:
# threshold is intentionally set higher than usual values since we're testing with random unconverged models
# that will not satisfy the expected properties of the denoiser for caching to be effective
first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
def test_first_block_cache_inference(self, expected_atol: float = 0.1):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
def create_pipe():
torch.manual_seed(0)
num_layers = 2
components = self.get_dummy_components(num_layers=num_layers)
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
return pipe
def run_forward(pipe):
torch.manual_seed(0)
inputs = self.get_dummy_inputs(device)
inputs["num_inference_steps"] = 4
return pipe(**inputs)[0]
# Run inference without FirstBlockCache
pipe = create_pipe()
output = run_forward(pipe).flatten()
original_image_slice = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache enabled
pipe = create_pipe()
pipe.transformer.enable_cache(self.first_block_cache_config)
output = run_forward(pipe).flatten()
image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FirstBlockCache disabled
pipe.transformer.disable_cache()
output = run_forward(pipe).flatten()
image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
"FirstBlockCache outputs should not differ much."
)
assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
"Outputs from normal inference and after disabling cache should not differ."
)
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
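
The new `FirstBlockCacheTesterMixin` exercises the same public API a user would call: enable the cache on the denoiser with a `FirstBlockCacheConfig`, run inference, then disable it and confirm the outputs match the uncached run. Roughly speaking, first-block caching compares how much the first transformer block's output changes between steps and, when the change stays below the threshold, reuses cached results for the remaining blocks. A hedged usage sketch; the checkpoint, dtype, and threshold value are illustrative, not recommendations:

```py
import torch
from diffusers import FirstBlockCacheConfig, FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Smaller thresholds are more conservative (fewer cache hits); larger values
# trade output fidelity for speed.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.05))

image = pipe(
    "a photo of an astronaut riding a horse on mars", num_inference_steps=28
).images[0]

# Restore exact, uncached behavior.
pipe.transformer.disable_cache()
```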

View File

@@ -872,6 +872,7 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):

View File

@@ -837,6 +837,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):