diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml index 16e1a70b84..384f07506a 100644 --- a/.github/workflows/nightly_tests.yml +++ b/.github/workflows/nightly_tests.yml @@ -248,7 +248,7 @@ jobs: BIG_GPU_MEMORY: 40 run: | python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \ - -m "big_gpu_with_torch_cuda" \ + -m "big_accelerator" \ --make-reports=tests_big_gpu_torch_cuda \ --report-log=tests_big_gpu_torch_cuda.log \ tests/ diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml index 87d5177388..fd0c76c2ff 100644 --- a/.github/workflows/pr_tests_gpu.yml +++ b/.github/workflows/pr_tests_gpu.yml @@ -188,7 +188,7 @@ jobs: shell: bash strategy: fail-fast: false - max-parallel: 2 + max-parallel: 4 matrix: module: [models, schedulers, lora, others] steps: diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md index e90cb32c54..9ba4742085 100644 --- a/docs/source/en/api/cache.md +++ b/docs/source/en/api/cache.md @@ -28,3 +28,9 @@ Cache methods speed up diffusion transformers by storing and reusing intermediate [[autodoc]] FasterCacheConfig [[autodoc]] apply_faster_cache + +### FirstBlockCacheConfig + +[[autodoc]] FirstBlockCacheConfig + +[[autodoc]] apply_first_block_cache diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md index 4e2d144421..40e290e4bd 100644 --- a/docs/source/en/api/pipelines/chroma.md +++ b/docs/source/en/api/pipelines/chroma.md @@ -36,7 +36,7 @@ import torch from diffusers import ChromaPipeline pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16) -pipe.enabe_model_cpu_offload() +pipe.enable_model_cpu_offload() prompt = [ "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done." diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md index df3df92f06..11afbf29d3 100644 --- a/docs/source/en/using-diffusers/other-formats.md +++ b/docs/source/en/using-diffusers/other-formats.md @@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file( -#### LoRA files +#### LoRAs -[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/). +[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoid missing or incorrect LoRA configurations. -LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
- -```py -from diffusers import StableDiffusionXLPipeline -import torch - -# base model -pipeline = StableDiffusionXLPipeline.from_pretrained( - "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16" -).to("cuda") - -# download LoRA weights -!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors - -# load LoRA weights -pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors") -prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop" -negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture" - -image = pipeline( - prompt=prompt, - negative_prompt=negative_prompt, - generator=torch.manual_seed(0), -).images[0] -image -``` +The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
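You can also read the embedded metadata programmatically. Below is a minimal sketch using the `safetensors` library (the local file name is a placeholder, and metadata values are stored as strings):

```py
# Sketch: inspect the LoRA configuration metadata embedded in a safetensors file.
from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    print(f.metadata())  # e.g. {"r": "16", "lora_alpha": "16", ...}
```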
+For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file. + +```py +import torch +from diffusers import FluxPipeline + +pipeline = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 +).to("cuda") +pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA") +pipeline.save_lora_weights( + transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16}, + text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8} +) +``` + ### ckpt > [!WARNING] diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md index 24c71d5c56..18273746c2 100644 --- a/examples/dreambooth/README_flux.md +++ b/examples/dreambooth/README_flux.md @@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N ## Training Kontext [Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We -provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too. +provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too. -Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section. +> [!IMPORTANT] +> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source, specifically from the commit mentioned below. +> To do this, execute the following steps in a new virtual environment: +> ``` +> git clone https://github.com/huggingface/diffusers +> cd diffusers +> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b +> pip install -e . +> ``` Below is an example training command: @@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \ Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not perform as expected. +Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets: + +* Condition image +* Target image +* Instruction + +[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset.
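Before launching a run, it can help to sanity-check that your dataset exposes these triplets. A minimal sketch with the `datasets` library (the column names mirror the training command below and may differ for your own dataset):

```py
# Sketch: peek at one (condition image, target image, instruction) triplet.
from datasets import load_dataset

ds = load_dataset("kontext-community/relighting", split="train")
sample = ds[0]
print(sample["instruction"])     # edit instruction
print(sample["file_name"].size)  # condition image (PIL)
print(sample["output"].size)     # target image (PIL)
```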
If you are using such a dataset, you can use the command below to launch training: + +```bash +accelerate launch train_dreambooth_lora_flux_kontext.py \ + --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \ + --output_dir="kontext-i2i" \ + --dataset_name="kontext-community/relighting" \ + --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \ + --mixed_precision="bf16" \ + --resolution=1024 \ + --train_batch_size=1 \ + --guidance_scale=1 \ + --gradient_accumulation_steps=4 \ + --gradient_checkpointing \ + --optimizer="adamw" \ + --use_8bit_adam \ + --cache_latents \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=200 \ + --max_train_steps=1000 \ + --rank=16 \ + --seed="0" +``` + +More generally, when performing I2I fine-tuning, we expect you to: + +* Have a dataset structured like `kontext-community/relighting` +* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training + ### Misc notes * By default, we use `mode` as the value of the `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it. @@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a Since Flux Kontext finetuning is still in an experimental phase, we encourage you to explore different settings and share your insights! 🤗 ## Other notes -Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ \ No newline at end of file +Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️ diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index 9f97567b06..5bd9b8684d 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torch.utils.data.sampler import BatchSampler from torchvision import transforms -from torchvision.transforms.functional import crop +from torchvision.transforms import functional as TF from tqdm.auto import tqdm from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast @@ -62,11 +62,7 @@ from diffusers.training_utils import ( free_memory, parse_buckets_string, ) -from diffusers.utils import ( - check_min_version, - convert_unet_state_dict_to_peft, - is_wandb_available, -) +from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card from diffusers.utils.import_utils import is_torch_npu_available from diffusers.utils.torch_utils import is_compiled_module @@ -186,6 +182,7 @@ def log_validation( ) pipeline = pipeline.to(accelerator.device, dtype=torch_dtype) pipeline.set_progress_bar_config(disable=True) + pipeline_args_cp = pipeline_args.copy() # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None @@ -193,14 +190,16 @@ def log_validation( # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast with torch.no_grad(): - prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt( - pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] - ) + prompt = pipeline_args_cp.pop("prompt") + prompt_embeds, pooled_prompt_embeds, text_ids =
pipeline.encode_prompt(prompt, prompt_2=None) images = [] for _ in range(args.num_validation_images): with autocast_ctx: image = pipeline( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + **pipeline_args_cp, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + generator=generator, ).images[0] images.append(image) @@ -310,6 +309,12 @@ def parse_args(input_args=None): "default, the standard Image Dataset maps out 'file_name' " "to 'image'.", ) + parser.add_argument( + "--cond_image_column", + type=str, + default=None, + help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning.", + ) parser.add_argument( "--caption_column", type=str, @@ -330,7 +335,6 @@ def parse_args(input_args=None): "--instance_prompt", type=str, default=None, - required=True, help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( @@ -351,6 +355,12 @@ def parse_args(input_args=None): default=None, help="A prompt that is used during validation to verify that the model is learning.", ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.", + ) parser.add_argument( "--num_validation_images", type=int, @@ -399,7 +409,7 @@ def parse_args(input_args=None): parser.add_argument( "--output_dir", type=str, - default="flux-dreambooth-lora", + default="flux-kontext-lora", help="The output directory where the model predictions and checkpoints will be written.", ) parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") @@ -716,6 +726,8 @@ def parse_args(input_args=None): raise ValueError("You must specify a data directory for class images.") if args.class_prompt is None: raise ValueError("You must specify prompt for class images.") + if args.cond_image_column is not None: + raise ValueError("Prior preservation isn't supported with I2I training.") else: # logger is not available yet if args.class_data_dir is not None: @@ -723,6 +735,14 @@ def parse_args(input_args=None): if args.class_prompt is not None: warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + if args.cond_image_column is not None: + assert args.image_column is not None + assert args.caption_column is not None + assert args.dataset_name is not None + assert not args.train_text_encoder + if args.validation_prompt is not None: + assert args.validation_image is None or os.path.exists(args.validation_image) + return args @@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset): repeats=1, center_crop=False, buckets=None, + args=None, ): self.center_crop = center_crop @@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset): column_names = dataset["train"].column_names # 6. Get the column names for input/target. + if args.cond_image_column is not None and args.cond_image_column not in column_names: + raise ValueError( + f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) if args.image_column is None: image_column = column_names[0] logger.info(f"image column defaulting to {image_column}") @@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset): raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns.
Dataset columns are: {', '.join(column_names)}" ) - instance_images = dataset["train"][image_column] + instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))] + cond_images = None + cond_image_column = args.cond_image_column + if cond_image_column is not None: + cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))] + assert len(instance_images) == len(cond_images) if args.caption_column is None: logger.info( @@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset): self.custom_instance_prompts = None self.instance_images = [] - for img in instance_images: + self.cond_images = [] + for i, img in enumerate(instance_images): self.instance_images.extend(itertools.repeat(img, repeats)) + if args.dataset_name is not None and cond_images is not None: + self.cond_images.extend(itertools.repeat(cond_images[i], repeats)) self.pixel_values = [] - for image in self.instance_images: + self.cond_pixel_values = [] + for i, image in enumerate(self.instance_images): image = exif_transpose(image) if not image.mode == "RGB": image = image.convert("RGB") + dest_image = None + if self.cond_images: + dest_image = exif_transpose(self.cond_images[i]) + if not dest_image.mode == "RGB": + dest_image = dest_image.convert("RGB") width, height = image.size @@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset): self.size = (target_height, target_width) # based on the bucket assignment, define the transformations - train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR) - train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size) - train_flip = transforms.RandomHorizontalFlip(p=1.0) - train_transforms = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]), - ] + image, dest_image = self.paired_transform( + image, + dest_image=dest_image, + size=self.size, + center_crop=args.center_crop, + random_flip=args.random_flip, ) - image = train_resize(image) - if args.center_crop: - image = train_crop(image) - else: - y1, x1, h, w = train_crop.get_params(image, self.size) - image = crop(image, y1, x1, h, w) - if args.random_flip and random.random() < 0.5: - image = train_flip(image) - image = train_transforms(image) self.pixel_values.append((image, bucket_idx)) + if dest_image is not None: + self.cond_pixel_values.append((dest_image, bucket_idx)) self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images @@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset): instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images] example["instance_images"] = instance_image example["bucket_idx"] = bucket_idx + if self.cond_pixel_values: + dest_image, _ = self.cond_pixel_values[index % self.num_instance_images] + example["cond_images"] = dest_image if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] @@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset): return example + def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False): + # 1. Resize (deterministic) + resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR) + image = resize(image) + if dest_image is not None: + dest_image = resize(dest_image) + + # 2. 
Crop: either center or SAME random crop + if center_crop: + crop = transforms.CenterCrop(size) + image = crop(image) + if dest_image is not None: + dest_image = crop(dest_image) + else: + # get_params returns (i, j, h, w) + i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size) + image = TF.crop(image, i, j, h, w) + if dest_image is not None: + dest_image = TF.crop(dest_image, i, j, h, w) + + # 3. Random horizontal flip with the SAME coin flip + if random_flip: + do_flip = random.random() < 0.5 + if do_flip: + image = TF.hflip(image) + if dest_image is not None: + dest_image = TF.hflip(dest_image) + + # 4. ToTensor + Normalize (deterministic) + to_tensor = transforms.ToTensor() + normalize = transforms.Normalize([0.5], [0.5]) + image = normalize(to_tensor(image)) + if dest_image is not None: + dest_image = normalize(to_tensor(dest_image)) + + return (image, dest_image) if dest_image is not None else (image, None) + def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] @@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False): pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() batch = {"pixel_values": pixel_values, "prompts": prompts} + if any("cond_images" in example for example in examples): + cond_pixel_values = [example["cond_images"] for example in examples] + cond_pixel_values = torch.stack(cond_pixel_values) + cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float() + batch.update({"cond_pixel_values": cond_pixel_values}) return batch @@ -1318,6 +1393,7 @@ def main(args): "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2", + "proj_mlp", ] # now we will add new LoRA weights the transformer layers @@ -1534,7 +1610,10 @@ def main(args): buckets=buckets, repeats=args.repeats, center_crop=args.center_crop, + args=args, ) + if args.cond_image_column is not None: + logger.info("I2I fine-tuning enabled.") batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False) train_dataloader = torch.utils.data.DataLoader( train_dataset, @@ -1574,6 +1653,7 @@ def main(args): # Clear the memory here if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + text_encoder_one.cpu(), text_encoder_two.cpu() del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two free_memory() @@ -1605,19 +1685,41 @@ def main(args): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + elif train_dataset.custom_instance_prompts and not args.train_text_encoder: + cached_text_embeddings = [] + for batch in tqdm(train_dataloader, desc="Embedding prompts"): + batch_prompts = batch["prompts"] + prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( + batch_prompts, text_encoders, tokenizers + ) + cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids)) + + if args.validation_prompt is None: + text_encoder_one.cpu(), text_encoder_two.cpu() + del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two + free_memory() + vae_config_shift_factor = vae.config.shift_factor vae_config_scaling_factor = vae.config.scaling_factor vae_config_block_out_channels = vae.config.block_out_channels + has_image_input = args.cond_image_column is not None if args.cache_latents: latents_cache = [] + cond_latents_cache = [] for batch in tqdm(train_dataloader, desc="Caching latents"): with 
torch.no_grad(): batch["pixel_values"] = batch["pixel_values"].to( accelerator.device, non_blocking=True, dtype=weight_dtype ) latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist) + if has_image_input: + batch["cond_pixel_values"] = batch["cond_pixel_values"].to( + accelerator.device, non_blocking=True, dtype=weight_dtype + ) + cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist) if args.validation_prompt is None: + vae.cpu() del vae free_memory() @@ -1678,7 +1780,7 @@ def main(args): # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: - tracker_name = "dreambooth-flux-dev-lora" + tracker_name = "dreambooth-flux-kontext-lora" accelerator.init_trackers(tracker_name, config=vars(args)) # Train! @@ -1742,6 +1844,7 @@ def main(args): sigma = sigma.unsqueeze(-1) return sigma + has_guidance = unwrap_model(transformer).config.guidance_embeds for epoch in range(first_epoch, args.num_train_epochs): transformer.train() if args.train_text_encoder: @@ -1759,9 +1862,7 @@ def main(args): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if not args.train_text_encoder: - prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings( - prompts, text_encoders, tokenizers - ) + prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step] else: tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77) tokens_two = tokenize_prompt( @@ -1794,16 +1895,29 @@ def main(args): if args.cache_latents: if args.vae_encode_mode == "sample": model_input = latents_cache[step].sample() + if has_image_input: + cond_model_input = cond_latents_cache[step].sample() else: model_input = latents_cache[step].mode() + if has_image_input: + cond_model_input = cond_latents_cache[step].mode() else: pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + if has_image_input: + cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype) if args.vae_encode_mode == "sample": model_input = vae.encode(pixel_values).latent_dist.sample() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample() else: model_input = vae.encode(pixel_values).latent_dist.mode() + if has_image_input: + cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode() model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor model_input = model_input.to(dtype=weight_dtype) + if has_image_input: + cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor + cond_model_input = cond_model_input.to(dtype=weight_dtype) vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1) @@ -1814,6 +1928,17 @@ def main(args): accelerator.device, weight_dtype, ) + if has_image_input: + cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids( + cond_model_input.shape[0], + cond_model_input.shape[2] // 2, + cond_model_input.shape[3] // 2, + accelerator.device, + weight_dtype, + ) + cond_latents_ids[..., 0] = 1 + latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0) + # Sample noise that we'll add to the latents noise = torch.randn_like(model_input) bsz = model_input.shape[0] @@ -1834,7 +1959,6 @@ def main(args): # zt = (1 - texp) * x + texp * z1 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) noisy_model_input = (1.0 - sigmas) * model_input + sigmas 
* noise - packed_noisy_model_input = FluxKontextPipeline._pack_latents( noisy_model_input, batch_size=model_input.shape[0], @@ -1842,13 +1966,22 @@ def main(args): height=model_input.shape[2], width=model_input.shape[3], ) + orig_inp_shape = packed_noisy_model_input.shape + if has_image_input: + packed_cond_input = FluxKontextPipeline._pack_latents( + cond_model_input, + batch_size=cond_model_input.shape[0], + num_channels_latents=cond_model_input.shape[1], + height=cond_model_input.shape[2], + width=cond_model_input.shape[3], + ) + packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1) - # handle guidance - if unwrap_model(transformer).config.guidance_embeds: + # Kontext always has guidance + guidance = None + if has_guidance: guidance = torch.tensor([args.guidance_scale], device=accelerator.device) guidance = guidance.expand(model_input.shape[0]) - else: - guidance = None # Predict the noise residual model_pred = transformer( @@ -1862,6 +1995,8 @@ def main(args): img_ids=latent_image_ids, return_dict=False, )[0] + if has_image_input: + model_pred = model_pred[:, : orig_inp_shape[1]] model_pred = FluxKontextPipeline._unpack_latents( model_pred, height=model_input.shape[2] * vae_scale_factor, @@ -1970,6 +2105,8 @@ def main(args): torch_dtype=weight_dtype, ) pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args, @@ -2030,6 +2167,8 @@ def main(args): images = [] if args.validation_prompt and args.num_validation_images > 0: pipeline_args = {"prompt": args.validation_prompt} + if has_image_input and args.validation_image: + pipeline_args.update({"image": load_image(args.validation_image)}) images = log_validation( pipeline=pipeline, args=args, diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 4c383c817e..713472b4a5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -133,9 +133,11 @@ else: _import_structure["hooks"].extend( [ "FasterCacheConfig", + "FirstBlockCacheConfig", "HookRegistry", "PyramidAttentionBroadcastConfig", "apply_faster_cache", + "apply_first_block_cache", "apply_pyramid_attention_broadcast", ] ) @@ -751,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) from .models import ( diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py index 764ceb25b4..365bed3718 100644 --- a/src/diffusers/hooks/__init__.py +++ b/src/diffusers/hooks/__init__.py @@ -1,8 +1,23 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from ..utils import is_torch_available if is_torch_available(): from .faster_cache import FasterCacheConfig, apply_faster_cache + from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache from .group_offloading import apply_group_offloading from .hooks import HookRegistry, ModelHook from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py new file mode 100644 index 0000000000..3be77dd4ce --- /dev/null +++ b/src/diffusers/hooks/_common.py @@ -0,0 +1,30 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..models.attention_processor import Attention, MochiAttention + + +_ATTENTION_CLASSES = (Attention, MochiAttention) + +_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers") +_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",) +_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers") + +_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple( + { + *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS, + *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS, + } +) diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py new file mode 100644 index 0000000000..960d14e6fa --- /dev/null +++ b/src/diffusers/hooks/_helpers.py @@ -0,0 +1,264 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect +from dataclasses import dataclass +from typing import Any, Callable, Dict, Type + + +@dataclass +class AttentionProcessorMetadata: + skip_processor_output_fn: Callable[[Any], Any] + + +@dataclass +class TransformerBlockMetadata: + return_hidden_states_index: int = None + return_encoder_hidden_states_index: int = None + + _cls: Type = None + _cached_parameter_indices: Dict[str, int] = None + + def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None): + kwargs = kwargs or {} + if identifier in kwargs: + return kwargs[identifier] + if self._cached_parameter_indices is not None: + return args[self._cached_parameter_indices[identifier]] + if self._cls is None: + raise ValueError("Model class is not set for metadata.") + parameters = list(inspect.signature(self._cls.forward).parameters.keys()) + parameters = parameters[1:] # skip `self` + self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)} + if identifier not in self._cached_parameter_indices: + raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.") + index = self._cached_parameter_indices[identifier] + if index >= len(args): + raise ValueError(f"Expected {index} arguments but got {len(args)}.") + return args[index] + + +class AttentionProcessorRegistry: + _registry = {} + # TODO(aryan): this is only required for the time being because we need to do the registrations + # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular + # import errors because of the models imported in this file. + _is_registered = False + + @classmethod + def register(cls, model_class: Type, metadata: AttentionProcessorMetadata): + cls._register() + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> AttentionProcessorMetadata: + cls._register() + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + @classmethod + def _register(cls): + if cls._is_registered: + return + cls._is_registered = True + _register_attention_processors_metadata() + + +class TransformerBlockRegistry: + _registry = {} + # TODO(aryan): this is only required for the time being because we need to do the registrations + # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular + # import errors because of the models imported in this file. 
+ _is_registered = False + + @classmethod + def register(cls, model_class: Type, metadata: TransformerBlockMetadata): + cls._register() + metadata._cls = model_class + cls._registry[model_class] = metadata + + @classmethod + def get(cls, model_class: Type) -> TransformerBlockMetadata: + cls._register() + if model_class not in cls._registry: + raise ValueError(f"Model class {model_class} not registered.") + return cls._registry[model_class] + + @classmethod + def _register(cls): + if cls._is_registered: + return + cls._is_registered = True + _register_transformer_blocks_metadata() + + +def _register_attention_processors_metadata(): + from ..models.attention_processor import AttnProcessor2_0 + from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor + + # AttnProcessor2_0 + AttentionProcessorRegistry.register( + model_class=AttnProcessor2_0, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0, + ), + ) + + # CogView4AttnProcessor + AttentionProcessorRegistry.register( + model_class=CogView4AttnProcessor, + metadata=AttentionProcessorMetadata( + skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor, + ), + ) + + +def _register_transformer_blocks_metadata(): + from ..models.attention import BasicTransformerBlock + from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock + from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock + from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock + from ..models.transformers.transformer_hunyuan_video import ( + HunyuanVideoSingleTransformerBlock, + HunyuanVideoTokenReplaceSingleTransformerBlock, + HunyuanVideoTokenReplaceTransformerBlock, + HunyuanVideoTransformerBlock, + ) + from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock + from ..models.transformers.transformer_mochi import MochiTransformerBlock + from ..models.transformers.transformer_wan import WanTransformerBlock + + # BasicTransformerBlock + TransformerBlockRegistry.register( + model_class=BasicTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # CogVideoX + TransformerBlockRegistry.register( + model_class=CogVideoXBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # CogView4 + TransformerBlockRegistry.register( + model_class=CogView4TransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Flux + TransformerBlockRegistry.register( + model_class=FluxTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + TransformerBlockRegistry.register( + model_class=FluxSingleTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=1, + return_encoder_hidden_states_index=0, + ), + ) + + # HunyuanVideo + TransformerBlockRegistry.register( + model_class=HunyuanVideoTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoSingleTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + 
model_class=HunyuanVideoTokenReplaceTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + TransformerBlockRegistry.register( + model_class=HunyuanVideoTokenReplaceSingleTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # LTXVideo + TransformerBlockRegistry.register( + model_class=LTXVideoTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + # Mochi + TransformerBlockRegistry.register( + model_class=MochiTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=1, + ), + ) + + # Wan + TransformerBlockRegistry.register( + model_class=WanTransformerBlock, + metadata=TransformerBlockMetadata( + return_hidden_states_index=0, + return_encoder_hidden_states_index=None, + ), + ) + + +# fmt: off +def _skip_attention___ret___hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + return hidden_states + + +def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs): + hidden_states = kwargs.get("hidden_states", None) + encoder_hidden_states = kwargs.get("encoder_hidden_states", None) + if hidden_states is None and len(args) > 0: + hidden_states = args[0] + if encoder_hidden_states is None and len(args) > 1: + encoder_hidden_states = args[1] + return hidden_states, encoder_hidden_states + + +_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states +_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states +# fmt: on diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py new file mode 100644 index 0000000000..40ae8c5a26 --- /dev/null +++ b/src/diffusers/hooks/first_block_cache.py @@ -0,0 +1,227 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Tuple, Union + +import torch + +from ..utils import get_logger +from ..utils.torch_utils import unwrap_module +from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS +from ._helpers import TransformerBlockRegistry +from .hooks import BaseState, HookRegistry, ModelHook, StateManager + + +logger = get_logger(__name__) # pylint: disable=invalid-name + +_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook" +_FBC_BLOCK_HOOK = "fbc_block_hook" + + +@dataclass +class FirstBlockCacheConfig: + r""" + Configuration for [First Block + Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching). 
+ + Args: + threshold (`float`, defaults to `0.05`): + The threshold to determine whether or not a forward pass through all layers of the model is required. A + higher threshold usually results in a forward pass through a lower number of layers and faster inference, + but might lead to poorer generation quality. A lower threshold may not result in significant generation + speedup. The threshold is compared against the absmean difference of the residuals between the current and + cached outputs from the first transformer block. If the difference is below the threshold, the forward pass + is skipped. + """ + + threshold: float = 0.05 + + +class FBCSharedBlockState(BaseState): + def __init__(self) -> None: + super().__init__() + + self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.head_block_residual: torch.Tensor = None + self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None + self.should_compute: bool = True + + def reset(self): + self.tail_block_residuals = None + self.should_compute = True + + +class FBCHeadBlockHook(ModelHook): + _is_stateful = True + + def __init__(self, state_manager: StateManager, threshold: float): + self.state_manager = state_manager + self.threshold = threshold + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + + output = self.fn_ref.original_forward(*args, **kwargs) + is_output_tuple = isinstance(output, tuple) + + if is_output_tuple: + hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states + else: + hidden_states_residual = output - original_hidden_states + + shared_state: FBCSharedBlockState = self.state_manager.get_state() + hidden_states = encoder_hidden_states = None + should_compute = self._should_compute_remaining_blocks(hidden_states_residual) + shared_state.should_compute = should_compute + + if not should_compute: + # Apply caching + if is_output_tuple: + hidden_states = ( + shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index] + ) + else: + hidden_states = shared_state.tail_block_residuals[0] + output + + if self._metadata.return_encoder_hidden_states_index is not None: + assert is_output_tuple + encoder_hidden_states = ( + shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index] + ) + + if is_output_tuple: + return_output = [None] * len(output) + return_output[self._metadata.return_hidden_states_index] = hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states + return_output = tuple(return_output) + else: + return_output = hidden_states + output = return_output + else: + if is_output_tuple: + head_block_output = [None] * len(output) + head_block_output[0] = output[self._metadata.return_hidden_states_index] + head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index] + else: + head_block_output = output + shared_state.head_block_output = head_block_output + shared_state.head_block_residual = hidden_states_residual + + return output + + def reset_state(self, module): + self.state_manager.reset() + return module + + @torch.compiler.disable + def _should_compute_remaining_blocks(self, 
hidden_states_residual: torch.Tensor) -> bool: + shared_state = self.state_manager.get_state() + if shared_state.head_block_residual is None: + return True + prev_hidden_states_residual = shared_state.head_block_residual + absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean() + prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean() + diff = (absmean / prev_hidden_states_absmean).item() + return diff > self.threshold + + +class FBCBlockHook(ModelHook): + def __init__(self, state_manager: StateManager, is_tail: bool = False): + super().__init__() + self.state_manager = state_manager + self.is_tail = is_tail + self._metadata = None + + def initialize_hook(self, module): + unwrapped_module = unwrap_module(module) + self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__) + return module + + def new_forward(self, module: torch.nn.Module, *args, **kwargs): + original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs) + original_encoder_hidden_states = None + if self._metadata.return_encoder_hidden_states_index is not None: + original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs( + "encoder_hidden_states", args, kwargs + ) + + shared_state = self.state_manager.get_state() + + if shared_state.should_compute: + output = self.fn_ref.original_forward(*args, **kwargs) + if self.is_tail: + hidden_states_residual = encoder_hidden_states_residual = None + if isinstance(output, tuple): + hidden_states_residual = ( + output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0] + ) + encoder_hidden_states_residual = ( + output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1] + ) + else: + hidden_states_residual = output - shared_state.head_block_output + shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual) + return output + + if original_encoder_hidden_states is None: + return_output = original_hidden_states + else: + return_output = [None, None] + return_output[self._metadata.return_hidden_states_index] = original_hidden_states + return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states + return_output = tuple(return_output) + return return_output + + +def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None: + state_manager = StateManager(FBCSharedBlockState, (), {}) + remaining_blocks = [] + + for name, submodule in module.named_children(): + if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList): + continue + for index, block in enumerate(submodule): + remaining_blocks.append((f"{name}.{index}", block)) + + head_block_name, head_block = remaining_blocks.pop(0) + tail_block_name, tail_block = remaining_blocks.pop(-1) + + logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'") + _apply_fbc_head_block_hook(head_block, state_manager, config.threshold) + + for name, block in remaining_blocks: + logger.debug(f"Applying FBCBlockHook to '{name}'") + _apply_fbc_block_hook(block, state_manager) + + logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'") + _apply_fbc_block_hook(tail_block, state_manager, is_tail=True) + + +def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCHeadBlockHook(state_manager, threshold) + 
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK) + + +def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None: + registry = HookRegistry.check_if_exists_or_initialize(block) + hook = FBCBlockHook(state_manager, is_tail) + registry.register_hook(hook, _FBC_BLOCK_HOOK) diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py index 96231aadc3..6e097e5882 100644 --- a/src/diffusers/hooks/hooks.py +++ b/src/diffusers/hooks/hooks.py @@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple import torch from ..utils.logging import get_logger +from ..utils.torch_utils import unwrap_module logger = get_logger(__name__) # pylint: disable=invalid-name +class BaseState: + def reset(self, *args, **kwargs) -> None: + raise NotImplementedError( + "BaseState::reset is not implemented. Please implement this method in the derived class." + ) + + +class StateManager: + def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None): + self._state_cls = state_cls + self._init_args = init_args if init_args is not None else () + self._init_kwargs = init_kwargs if init_kwargs is not None else {} + self._state_cache = {} + self._current_context = None + + def get_state(self): + if self._current_context is None: + raise ValueError("No context is set. Please set a context before retrieving the state.") + if self._current_context not in self._state_cache.keys(): + self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs) + return self._state_cache[self._current_context] + + def set_context(self, name: str) -> None: + self._current_context = name + + def reset(self, *args, **kwargs) -> None: + for name, state in list(self._state_cache.items()): + state.reset(*args, **kwargs) + self._state_cache.pop(name) + self._current_context = None + + class ModelHook: r""" A hook that contains callbacks to be executed just before and after the forward method of a model. @@ -99,6 +132,14 @@ class ModelHook: raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.") return module + def _set_context(self, module: torch.nn.Module, name: str) -> None: + # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them. 
+ for attr_name in dir(self): + attr = getattr(self, attr_name) + if isinstance(attr, StateManager): + attr.set_context(name) + return module + class HookFunctionReference: def __init__(self) -> None: @@ -211,9 +252,10 @@ class HookRegistry: hook.reset_state(self._module_ref) if recurse: - for module_name, module in self._module_ref.named_modules(): + for module_name, module in unwrap_module(self._module_ref).named_modules(): if module_name == "": continue + module = unwrap_module(module) if hasattr(module, "_diffusers_hook"): module._diffusers_hook.reset_stateful_hooks(recurse=False) @@ -223,6 +265,19 @@ class HookRegistry: module._diffusers_hook = cls(module) return module._diffusers_hook + def _set_context(self, name: Optional[str] = None) -> None: + for hook_name in reversed(self._hook_order): + hook = self.hooks[hook_name] + if hook._is_stateful: + hook._set_context(self._module_ref, name) + + for module_name, module in unwrap_module(self._module_ref).named_modules(): + if module_name == "": + continue + module = unwrap_module(module) + if hasattr(module, "_diffusers_hook"): + module._diffusers_hook._set_context(name) + def __repr__(self) -> str: registry_repr = "" for i, hook_name in enumerate(self._hook_order): diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py index 3fd1ca6e9d..605c0d588c 100644 --- a/src/diffusers/models/cache_utils.py +++ b/src/diffusers/models/cache_utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from contextlib import contextmanager + from ..utils.logging import get_logger @@ -25,6 +27,7 @@ class CacheMixin: Supported caching techniques: - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) - [FasterCache](https://huggingface.co/papers/2410.19355) + - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) """ _cache_config = None @@ -62,8 +65,10 @@ class CacheMixin: from ..hooks import ( FasterCacheConfig, + FirstBlockCacheConfig, PyramidAttentionBroadcastConfig, apply_faster_cache, + apply_first_block_cache, apply_pyramid_attention_broadcast, ) @@ -72,31 +77,36 @@ class CacheMixin: f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first." 
) - if isinstance(config, PyramidAttentionBroadcastConfig): - apply_pyramid_attention_broadcast(self, config) - elif isinstance(config, FasterCacheConfig): + if isinstance(config, FasterCacheConfig): apply_faster_cache(self, config) + elif isinstance(config, FirstBlockCacheConfig): + apply_first_block_cache(self, config) + elif isinstance(config, PyramidAttentionBroadcastConfig): + apply_pyramid_attention_broadcast(self, config) else: raise ValueError(f"Cache config {type(config)} is not supported.") self._cache_config = config def disable_cache(self) -> None: - from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig + from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK + from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK if self._cache_config is None: logger.warning("Caching techniques have not been enabled, so there's nothing to disable.") return - if isinstance(self._cache_config, PyramidAttentionBroadcastConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) - registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) - elif isinstance(self._cache_config, FasterCacheConfig): - registry = HookRegistry.check_if_exists_or_initialize(self) + registry = HookRegistry.check_if_exists_or_initialize(self) + if isinstance(self._cache_config, FasterCacheConfig): registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True) registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, FirstBlockCacheConfig): + registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True) + registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True) + elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig): + registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True) else: raise ValueError(f"Cache config {type(self._cache_config)} is not supported.") @@ -106,3 +116,15 @@ class CacheMixin: from ..hooks import HookRegistry HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse) + + @contextmanager + def cache_context(self, name: str): + r"""Context manager that provides additional methods for cache management.""" + from ..hooks import HookRegistry + + registry = HookRegistry.check_if_exists_or_initialize(self) + registry._set_context(name) + + yield + + registry._set_context(None) diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py index d8e99ee45e..063ff5bd8e 100644 --- a/src/diffusers/models/controlnets/controlnet_flux.py +++ b/src/diffusers/models/controlnets/controlnet_flux.py @@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): ) block_samples = block_samples + (hidden_states,) - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + 
encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) - single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) + single_block_samples = single_block_samples + (hidden_states,) # controlnet block controlnet_block_samples = () diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py index e4144d0c8e..dc45befb98 100644 --- a/src/diffusers/models/transformers/transformer_cogview4.py +++ b/src/diffusers/models/transformers/transformer_cogview4.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph from ..attention import FeedForward from ..attention_processor import Attention from ..cache_utils import CacheMixin @@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor: return hidden_states, encoder_hidden_states +@maybe_allow_in_graph class CogView4TransformerBlock(nn.Module): def __init__( self, diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 3af1de2ad0..3a7202d0f4 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module): def forward( self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, ) -> torch.Tensor: + text_seq_len = encoder_hidden_states.shape[1] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) @@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module): if hidden_states.dtype == torch.float16: hidden_states = hidden_states.clip(-65504, 65504) - return hidden_states + encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] + return encoder_hidden_states, hidden_states @maybe_allow_in_graph @@ -507,20 +512,21 @@ class FluxTransformer2DModel( ) else: hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] - hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func( + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( block, hidden_states, + encoder_hidden_states, temb, image_rotary_emb, ) else: - hidden_states = block( + encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, @@ -530,12 +536,7 @@ class FluxTransformer2DModel( if controlnet_single_block_samples is not None: interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) interval_control = int(np.ceil(interval_control)) - hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 5fb71b69f7..bdb9201e62 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
@@ -249,6 +250,7 @@ class WanRotaryPosEmbed(nn.Module):
         return freqs_cos, freqs_sin


+@maybe_allow_in_graph
 class WanTransformerBlock(nn.Module):
     def __init__(
         self,
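`CogView4TransformerBlock` and `WanTransformerBlock` are decorated with the existing `maybe_allow_in_graph` helper so that `torch.compile` can keep the (now hook-wrapped) blocks inside the compiled graph. Roughly, the helper behaves like the following; this is a paraphrased sketch of the idea, not the exact diffusers implementation:

```py
import torch


def maybe_allow_in_graph(cls):
    # Sketch: when Dynamo is available, mark the class as safe to inline into a
    # compiled graph instead of forcing a graph break at each call site.
    if hasattr(torch, "_dynamo"):
        cls = torch._dynamo.allow_in_graph(cls)
    return cls
```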
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index f08a3c35c2..3c5994172c 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index fe3e8ae388..cf6ccebc47 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a982f4b275..d1f02ca9c9 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    ofs=ofs_emb,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        ofs=ofs_emb,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 7c50bdcb7d..230c8ca296 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 880253459e..d8374b694f 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0])

-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred_cond = self.transformer(
                         hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         timestep=timestep,
                         original_size=original_size,
                         target_size=target_size,
@@ -643,6 +631,19 @@
                         return_dict=False,
                     )[0]

+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            timestep=timestep,
+                            original_size=original_size,
+                            target_size=target_size,
+                            crop_coords=crops_coords_top_left,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
                     noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 4c83ae7405..073d94750a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -912,32 +912,35 @@ class FluxPipeline(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_image_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index b617e4f8b2..2cbb4af2b4 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if do_true_cfg:
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        encoder_attention_mask=negative_prompt_attention_mask,
-                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
                         guidance=guidance,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if do_true_cfg:
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_attention_mask=negative_prompt_attention_mask,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            guidance=guidance,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 3b58b4a45a..77ba751700 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index fa9ee4fc7b..217478f418 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
                 if is_conditioning_image_or_video:
                     timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    video_coords=video_coords,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        video_coords=video_coords,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 99412b6962..8793d81377 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 7712b41524..3c0f908296 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond_uncond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # Mochi CFG + Sampling runs in FP32
                 noise_pred = noise_pred.to(torch.float32)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 6df66118b0..d14dac91f1 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 2981f3a420..6d25047a0f 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class FirstBlockCacheConfig(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class HookRegistry(metaclass=DummyObject):
     _backends = ["torch"]

@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
     requires_backends(apply_faster_cache, ["torch"])


+def apply_first_block_cache(*args, **kwargs):
+    requires_backends(apply_first_block_cache, ["torch"])
+
+
 def apply_pyramid_attention_broadcast(*args, **kwargs):
     requires_backends(apply_pyramid_attention_broadcast, ["torch"])
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index e5da39c1d8..ebb3d70553 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -421,6 +421,10 @@ def require_big_accelerator(test_case):
     Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example
     pipelines: Flux, SD3, Cog, etc.
     """
+    import pytest
+
+    test_case = pytest.mark.big_accelerator(test_case)
+
     if not is_torch_available():
         return unittest.skip("test requires PyTorch")(test_case)
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index ffc1119727..61a5d95b69 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
     return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)


+def unwrap_module(module):
+    """Unwraps a module if it was compiled with torch.compile()"""
+    return module._orig_mod if is_compiled_module(module) else module
+
+
 def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
     """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
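`unwrap_module` gives hook and cache code a single way to reach the original module behind a `torch.compile` wrapper. A small usage sketch, assuming the helper is importable from its location above; the toy model is illustrative:

```py
import torch
import torch.nn as nn

from diffusers.utils.torch_utils import unwrap_module

model = nn.Linear(4, 4)
compiled = torch.compile(model)

# torch.compile wraps the module in an OptimizedModule; unwrap_module recovers
# the original instance so attributes and hooks resolve against the real module.
assert unwrap_module(compiled) is model
assert unwrap_module(model) is model
```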
diff --git a/tests/conftest.py b/tests/conftest.py
index 7e9c4e8f39..3237fb9c7b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,10 @@ sys.path.insert(1, git_repo_path)
 warnings.simplefilter(action="ignore", category=FutureWarning)


+def pytest_configure(config):
+    config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+
+
 def pytest_addoption(parser):
     from diffusers.utils.testing_utils import pytest_addoption_shared
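With the marker registered here and applied inside `require_big_accelerator`, the explicit `@pytest.mark.big_accelerator` lines removed in the test files below become redundant: the decorator alone is enough for `-m "big_accelerator"` selection, as used in the nightly workflow. A sketch of the resulting test-side pattern; `MyBigPipelineTests` is a hypothetical class, not part of this diff:

```py
import unittest

from diffusers.utils.testing_utils import require_big_accelerator


@require_big_accelerator  # now also applies pytest.mark.big_accelerator internally
class MyBigPipelineTests(unittest.TestCase):
    def test_inference(self):
        ...
```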
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 336ac2246f..95f1e137e9 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -20,7 +20,6 @@ import tempfile
 import unittest

 import numpy as np
-import pytest
 import safetensors.torch
 import torch
 from parameterized import parameterized
@@ -813,7 +812,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on audace.
@@ -960,7 +958,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxControlLoRAIntegrationTests(unittest.TestCase):
     num_inference_steps = 10
     seed = 0
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 19e31f320d..4cbd6523e7 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -17,7 +17,6 @@ import sys
 import unittest

 import numpy as np
-import pytest
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast

@@ -198,7 +197,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
     """internal note: The integration slices were obtained on DGX.
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8a8f2a676d..8928ccbac2 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,7 +17,6 @@ import sys
 import unittest

 import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -139,7 +138,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 @require_torch_accelerator
 @require_peft_backend
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class SD3LoraIntegrationTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Img2ImgPipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index c725589781..a6cb558513 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     check_qkv_fusion_matches_attn_procs_length,
@@ -45,7 +46,11 @@ enable_full_determinism()


 class CogVideoXPipelineFastTests(
-    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = CogVideoXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5ee94b09ba..5b336edc7a 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,6 @@ import gc
 import unittest

 import numpy as np
-import pytest
 import torch
 from huggingface_hub import hf_hub_download
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

@@ -211,7 +210,6 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl

 @nightly
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxControlNetPipeline
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 712c26b0a2..1f1f800bcf 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -18,7 +18,6 @@ import unittest
 from typing import Optional

 import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -221,7 +220,6 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes

 @slow
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3ControlNetPipeline
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index cbdf617d71..0df0e028ff 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,7 +2,6 @@ import gc
 import unittest

 import numpy as np
-import pytest
 import torch
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -25,6 +24,7 @@ from diffusers.utils.testing_utils import (

 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     FluxIPAdapterTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +34,12 @@ from ..test_pipelines_common import (


 class FluxPipelineFastTests(
-    unittest.TestCase,
     PipelineTesterMixin,
     FluxIPAdapterTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = FluxPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -224,7 +225,6 @@ class FluxPipelineFastTests(

 @nightly
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,6 @@ class FluxPipelineSlowTests(unittest.TestCase):

 @slow
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-dev"
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index b8f36dfd3c..b73050a64d 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -2,7 +2,6 @@ import gc
 import unittest

 import numpy as np
-import pytest
 import torch

 from diffusers import FluxPipeline, FluxPriorReduxPipeline
@@ -19,7 +18,6 @@ from diffusers.utils.testing_utils import (

 @slow
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class FluxReduxSlowTests(unittest.TestCase):
     pipeline_class = FluxPriorReduxPipeline
     repo_id = "black-forest-labs/FLUX.1-Redux-dev"
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index ecc5eba964..10101af75c 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (

 from ..test_pipelines_common import (
     FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
     PipelineTesterMixin,
     PyramidAttentionBroadcastTesterMixin,
     to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()


 class HunyuanVideoPipelineFastTests(
-    PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+    PipelineTesterMixin,
+    PyramidAttentionBroadcastTesterMixin,
+    FasterCacheTesterMixin,
+    FirstBlockCacheTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = HunyuanVideoPipeline
     params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index 1d1eb08234..bf0c7fde59 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
 from diffusers.utils.testing_utils import enable_full_determinism, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
     pipeline_class = LTXPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     test_layerwise_casting = True
     test_group_offloading = True

-    def get_dummy_components(self):
+    def get_dummy_components(self, num_layers: int = 1):
         torch.manual_seed(0)
         transformer = LTXVideoTransformer3DModel(
             in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             num_attention_heads=4,
             attention_head_dim=8,
             cross_attention_dim=32,
-            num_layers=1,
+            num_layers=num_layers,
             caption_channels=32,
         )
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index 5b00261b06..f1684cce72 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -17,7 +17,6 @@ import inspect
 import unittest

 import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, T5EncoderModel

@@ -33,13 +32,15 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np


 enable_full_determinism()


-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+    PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
     pipeline_class = MochiPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -268,7 +269,6 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
 @nightly
 @require_torch_accelerator
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class MochiPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 577ac4ebdd..2179ec8e22 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,7 +2,6 @@ import gc
 import unittest

 import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -233,7 +232,6 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):

 @slow
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class StableDiffusion3PipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Pipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index f5b5e63a81..7f913cb63d 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,7 +3,6 @@ import random
 import unittest

 import numpy as np
-import pytest
 import torch
 from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

@@ -168,7 +167,6 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte

 @slow
 @require_big_accelerator
-@pytest.mark.big_accelerator
 class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3Img2ImgPipeline
     repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index f87778b260..13c25ccaa4 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -33,6 +33,7 @@ from diffusers import (
 )
 from diffusers.hooks import apply_group_offloading
 from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
 from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -2648,7 +2649,7 @@ class FasterCacheTesterMixin:
         self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
         pipe = create_pipe()
         pipe.transformer.enable_cache(self.faster_cache_config)
-        output = run_forward(pipe).flatten().flatten()
+        output = run_forward(pipe).flatten()
         image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))

         # Run inference with FasterCache disabled
@@ -2755,6 +2756,55 @@ class FasterCacheTesterMixin:
         self.assertTrue(state.cache is None, "Cache should be reset to None.")


+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+    # The threshold is intentionally set higher than usual: the tests run against small, randomly initialized
+    # models whose outputs do not have the properties of a trained denoiser, so caching would rarely trigger
+    # at realistic thresholds.
+    first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
+    def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        def create_pipe():
+            torch.manual_seed(0)
+            num_layers = 2
+            components = self.get_dummy_components(num_layers=num_layers)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(device)
+            pipe.set_progress_bar_config(disable=None)
+            return pipe
+
+        def run_forward(pipe):
+            torch.manual_seed(0)
+            inputs = self.get_dummy_inputs(device)
+            inputs["num_inference_steps"] = 4
+            return pipe(**inputs)[0]
+
+        # Run inference without FirstBlockCache
+        pipe = create_pipe()
+        output = run_forward(pipe).flatten()
+        original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+        # Run inference with FirstBlockCache enabled
+        pipe = create_pipe()
+        pipe.transformer.enable_cache(self.first_block_cache_config)
+        output = run_forward(pipe).flatten()
+        image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+        # Run inference with FirstBlockCache disabled
+        pipe.transformer.disable_cache()
+        output = run_forward(pipe).flatten()
+        image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+        assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+            "FirstBlockCache outputs should not differ much."
+        )
+        assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+            "Outputs from normal inference and after disabling cache should not differ."
+        )
+
+
 # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
 # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
 # reference image.
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 06116cac3a..98005cfbc8 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -872,6 +872,7 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):


 @require_torch_version_greater("2.7.1")
+@require_bitsandbytes_version_greater("0.45.5")
 class Bnb4BitCompileTests(QuantCompileTests):
     @property
     def quantization_config(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 2ea4cdfde8..f3bbc34e8b 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -837,6 +837,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):


 @require_torch_version_greater_equal("2.6.0")
+@require_bitsandbytes_version_greater("0.45.5")
 class Bnb8BitCompileTests(QuantCompileTests):
     @property
     def quantization_config(self):