diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 696097fd54..747e1d8154 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -11,17 +11,18 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+ BASE_PATH: benchmark_outputs
jobs:
- torch_pipelines_cuda_benchmark_tests:
+ torch_models_cuda_benchmark_tests:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
- name: Torch Core Pipelines CUDA Benchmarking Tests
+ name: Torch Core Models CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on:
- group: aws-g6-4xlarge-plus
+ group: aws-g6e-4xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
@@ -35,27 +36,47 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
+ apt update
+ apt install -y libpq-dev postgresql-client
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
- python -m uv pip install pandas peft
- python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
+ python -m uv pip install -r benchmarks/requirements.txt
- name: Environment
run: |
python utils/print_env.py
- name: Diffusers Benchmarking
env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
- BASE_PATH: benchmark_outputs
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
- cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
+ cd benchmarks && python run_all.py
+
+ - name: Push results to the Hub
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
+ run: |
+ cd benchmarks && python push_results.py
+ mkdir $BASE_PATH && cp *.csv $BASE_PATH
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
- path: benchmarks/benchmark_outputs
+ path: benchmarks/${{ env.BASE_PATH }}
+
+ # TODO: enable this once the connection problem has been resolved.
+ - name: Update benchmarking results to DB
+ env:
+ PGDATABASE: metrics
+ PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
+ PGUSER: transformers_benchmarks
+ PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
+ BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+ run: |
+ git config --global --add safe.directory /__w/diffusers/diffusers
+ commit_id=$GITHUB_SHA
+ commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
+ cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index dca7ca5820..549dff2c21 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -255,7 +255,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -m "big_gpu_with_torch_cuda" \
+ -m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index f7b21e9b9a..034f7c3c63 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -188,7 +188,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
- max-parallel: 2
+ max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000000..574779bb50
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,69 @@
+# Diffusers Benchmarks
+
+Welcome to Diffusers Benchmarks. These benchmarks are used to obtain latency and memory information for the most popular models across different scenarios such as:
+
+* Base case, i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
+* Base + `torch.compile()`
+* NF4 quantization
+* Layerwise upcasting
+
+Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is benchmarked with real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
+
+The entrypoint for running all the currently available benchmarks is `run_all.py`. You can also run individual benchmarks, e.g., `python benchmarking_flux.py`. Each script produces a CSV file containing various information about the benchmark run.
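+
+For reference, here is a minimal sketch of inspecting such a CSV with pandas (assuming `python benchmarking_flux.py` has been run and produced `flux.csv`):
+
+```py
+import pandas as pd
+
+# Columns written by the benchmarking scripts include: scenario, model_cls,
+# num_params_B, flops_G, time_plain_s, mem_plain_GB, time_compile_s,
+# mem_compile_GB, fullgraph, and mode.
+df = pd.read_csv("flux.csv")
+print(df.T)
+```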
+
+The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
+
+## Running the benchmarks manually
+
+First, set up `torch` and install `diffusers` from the root of the repository:
+
+```sh
+pip install -e ".[quality,test]"
+```
+
+Then make sure the other dependencies are installed:
+
+```sh
+cd benchmarks/
+pip install -r requirements.txt
+```
+
+We need to be authenticated to access some of the checkpoints used during benchmarking:
+
+```sh
+huggingface-cli login
+```
+
+The benchmark CI runs on an L40 GPU with 128GB of RAM, and the benchmarks are configured for NVIDIA GPUs. Make sure you have access to a similar machine, or modify the benchmarking scripts accordingly.
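+
+If you are unsure whether your machine is comparable, a quick way to check the available GPU and its memory is:
+
+```py
+import torch
+
+# Print the GPU name and its total memory in GiB.
+print(torch.cuda.get_device_name(0))
+print(f"{torch.cuda.get_device_properties(0).total_memory / 1024**3:.0f} GiB")
+```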
+
+Then you can either launch the entire benchmarking suite by running:
+
+```sh
+python run_all.py
+```
+
+Or, you can run the individual benchmarks.
+
+## Customizing the benchmarks
+
+We define "scenarios" to cover the most common ways in which these models are used. You can
+define a new scenario, modifying an existing benchmark file:
+
+```py
+BenchmarkScenario(
+ name=f"{CKPT_ID}-bnb-8bit",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+)
+```
+
+You can also add a new model-level benchmark to the existing suite. To do so, defining a valid benchmarking file like `benchmarking_flux.py` should be enough, as sketched below.
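+
+A minimal sketch of such a file, mirroring `benchmarking_flux.py` (the module name `benchmarking_my_model.py` and the output file `my_model.csv` are placeholders; swap in the model class, checkpoint, and inputs you want to benchmark):
+
+```py
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import FluxTransformer2DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "black-forest-labs/FLUX.1-dev"  # replace with the checkpoint you want to benchmark
+RESULT_FILENAME = "my_model.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+    # Inputs for a single forward pass of the model under test
+    # (here: Flux at 1024x1024 with a maximum sequence length of 512).
+    return {
+        "hidden_states": torch.randn(1, 4096, 64, **device_dtype_kwargs),
+        "encoder_hidden_states": torch.randn(1, 512, 4096, **device_dtype_kwargs),
+        "pooled_projections": torch.randn(1, 768, **device_dtype_kwargs),
+        "img_ids": torch.ones(4096, 3, **device_dtype_kwargs),
+        "txt_ids": torch.ones(512, 3, **device_dtype_kwargs),
+        "timestep": torch.tensor([1.0], **device_dtype_kwargs),
+        "guidance": torch.tensor([1.0], **device_dtype_kwargs),
+    }
+
+
+if __name__ == "__main__":
+    scenarios = [
+        BenchmarkScenario(
+            name=f"{CKPT_ID}-bf16",
+            model_cls=FluxTransformer2DModel,
+            model_init_kwargs={
+                "pretrained_model_name_or_path": CKPT_ID,
+                "torch_dtype": torch.bfloat16,
+                "subfolder": "transformer",
+            },
+            get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+            model_init_fn=model_init_fn,
+            compile_kwargs={"fullgraph": True},
+        ),
+    ]
+    BenchmarkMixin().run_benchmarks_and_collate(scenarios, filename=RESULT_FILENAME)
+```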
+
+Happy benchmarking 🧨
\ No newline at end of file
diff --git a/tests/pipelines/amused/__init__.py b/benchmarks/__init__.py
similarity index 100%
rename from tests/pipelines/amused/__init__.py
rename to benchmarks/__init__.py
diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py
deleted file mode 100644
index 45bf65c93c..0000000000
--- a/benchmarks/base_classes.py
+++ /dev/null
@@ -1,346 +0,0 @@
-import os
-import sys
-
-import torch
-
-from diffusers import (
- AutoPipelineForImage2Image,
- AutoPipelineForInpainting,
- AutoPipelineForText2Image,
- ControlNetModel,
- LCMScheduler,
- StableDiffusionAdapterPipeline,
- StableDiffusionControlNetPipeline,
- StableDiffusionXLAdapterPipeline,
- StableDiffusionXLControlNetPipeline,
- T2IAdapter,
- WuerstchenCombinedPipeline,
-)
-from diffusers.utils import load_image
-
-
-sys.path.append(".")
-
-from utils import ( # noqa: E402
- BASE_PATH,
- PROMPT,
- BenchmarkInfo,
- benchmark_fn,
- bytes_to_giga_bytes,
- flush,
- generate_csv_dict,
- write_to_csv,
-)
-
-
-RESOLUTION_MAPPING = {
- "Lykon/DreamShaper": (512, 512),
- "lllyasviel/sd-controlnet-canny": (512, 512),
- "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
- "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
- "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
- "stabilityai/stable-diffusion-2-1": (768, 768),
- "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
- "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
- "stabilityai/sdxl-turbo": (512, 512),
-}
-
-
-class BaseBenchmak:
- pipeline_class = None
-
- def __init__(self, args):
- super().__init__()
-
- def run_inference(self, args):
- raise NotImplementedError
-
- def benchmark(self, args):
- raise NotImplementedError
-
- def get_result_filepath(self, args):
- pipeline_class_name = str(self.pipe.__class__.__name__)
- name = (
- args.ckpt.replace("/", "_")
- + "_"
- + pipeline_class_name
- + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
- )
- filepath = os.path.join(BASE_PATH, name)
- return filepath
-
-
-class TextToImageBenchmark(BaseBenchmak):
- pipeline_class = AutoPipelineForText2Image
-
- def __init__(self, args):
- pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- if args.run_compile:
- if not isinstance(pipe, WuerstchenCombinedPipeline):
- pipe.unet.to(memory_format=torch.channels_last)
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
- pipe.movq.to(memory_format=torch.channels_last)
- pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
- else:
- print("Run torch compile")
- pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
- pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
- def benchmark(self, args):
- flush()
-
- print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
-
- time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
- memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
- benchmark_info = BenchmarkInfo(time=time, memory=memory)
-
- pipeline_class_name = str(self.pipe.__class__.__name__)
- flush()
- csv_dict = generate_csv_dict(
- pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
- )
- filepath = self.get_result_filepath(args)
- write_to_csv(filepath, csv_dict)
- print(f"Logs written to: {filepath}")
- flush()
-
-
-class TurboTextToImageBenchmark(TextToImageBenchmark):
- def __init__(self, args):
- super().__init__(args)
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=0.0,
- )
-
-
-class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
- lora_id = "latent-consistency/lcm-lora-sdxl"
-
- def __init__(self, args):
- super().__init__(args)
- self.pipe.load_lora_weights(self.lora_id)
- self.pipe.fuse_lora()
- self.pipe.unload_lora_weights()
- self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
-
- def get_result_filepath(self, args):
- pipeline_class_name = str(self.pipe.__class__.__name__)
- name = (
- self.lora_id.replace("/", "_")
- + "_"
- + pipeline_class_name
- + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
- )
- filepath = os.path.join(BASE_PATH, name)
- return filepath
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=1.0,
- )
-
- def benchmark(self, args):
- flush()
-
- print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
-
- time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
- memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
- benchmark_info = BenchmarkInfo(time=time, memory=memory)
-
- pipeline_class_name = str(self.pipe.__class__.__name__)
- flush()
- csv_dict = generate_csv_dict(
- pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
- )
- filepath = self.get_result_filepath(args)
- write_to_csv(filepath, csv_dict)
- print(f"Logs written to: {filepath}")
- flush()
-
-
-class ImageToImageBenchmark(TextToImageBenchmark):
- pipeline_class = AutoPipelineForImage2Image
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
- image = load_image(url).convert("RGB")
-
- def __init__(self, args):
- super().__init__(args)
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class TurboImageToImageBenchmark(ImageToImageBenchmark):
- def __init__(self, args):
- super().__init__(args)
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=0.0,
- strength=0.5,
- )
-
-
-class InpaintingBenchmark(ImageToImageBenchmark):
- pipeline_class = AutoPipelineForInpainting
- mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
- mask = load_image(mask_url).convert("RGB")
-
- def __init__(self, args):
- super().__init__(args)
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
- self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- mask_image=self.mask,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
- url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
- image = load_image(url)
-
- def __init__(self, args):
- pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
- pipe.load_ip_adapter(
- args.ip_adapter_id[0],
- subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
- weight_name=args.ip_adapter_id[1],
- )
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- ip_adapter_image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class ControlNetBenchmark(TextToImageBenchmark):
- pipeline_class = StableDiffusionControlNetPipeline
- aux_network_class = ControlNetModel
- root_ckpt = "Lykon/DreamShaper"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
- image = load_image(url).convert("RGB")
-
- def __init__(self, args):
- aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.controlnet.to(memory_format=torch.channels_last)
-
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
-
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class ControlNetSDXLBenchmark(ControlNetBenchmark):
- pipeline_class = StableDiffusionXLControlNetPipeline
- root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
-
- def __init__(self, args):
- super().__init__(args)
-
-
-class T2IAdapterBenchmark(ControlNetBenchmark):
- pipeline_class = StableDiffusionAdapterPipeline
- aux_network_class = T2IAdapter
- root_ckpt = "Lykon/DreamShaper"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
- image = load_image(url).convert("L")
-
- def __init__(self, args):
- aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.adapter.to(memory_format=torch.channels_last)
-
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
-
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
-
-class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
- pipeline_class = StableDiffusionXLAdapterPipeline
- root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
- image = load_image(url)
-
- def __init__(self, args):
- super().__init__(args)
diff --git a/benchmarks/benchmark_controlnet.py b/benchmarks/benchmark_controlnet.py
deleted file mode 100644
index 9217004461..0000000000
--- a/benchmarks/benchmark_controlnet.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="lllyasviel/sd-controlnet-canny",
- choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = (
- ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
- )
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_ip_adapters.py b/benchmarks/benchmark_ip_adapters.py
deleted file mode 100644
index 9a31a21fc6..0000000000
--- a/benchmarks/benchmark_ip_adapters.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
-
-
-IP_ADAPTER_CKPTS = {
- # because original SD v1.5 has been taken down.
- "Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
- "stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
-}
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="rstabilityai/stable-diffusion-xl-base-1.0",
- choices=list(IP_ADAPTER_CKPTS.keys()),
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
- benchmark_pipe = IPAdapterTextToImageBenchmark(args)
- args.ckpt = f"{args.ckpt} (IP-Adapter)"
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_sd_img.py b/benchmarks/benchmark_sd_img.py
deleted file mode 100644
index 772befe879..0000000000
--- a/benchmarks/benchmark_sd_img.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=[
- "Lykon/DreamShaper",
- "stabilityai/stable-diffusion-2-1",
- "stabilityai/stable-diffusion-xl-refiner-1.0",
- "stabilityai/sdxl-turbo",
- ],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_sd_inpainting.py b/benchmarks/benchmark_sd_inpainting.py
deleted file mode 100644
index 143adcb0d8..0000000000
--- a/benchmarks/benchmark_sd_inpainting.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import InpaintingBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=[
- "Lykon/DreamShaper",
- "stabilityai/stable-diffusion-2-1",
- "stabilityai/stable-diffusion-xl-base-1.0",
- ],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = InpaintingBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_t2i_adapter.py b/benchmarks/benchmark_t2i_adapter.py
deleted file mode 100644
index 44b04b470e..0000000000
--- a/benchmarks/benchmark_t2i_adapter.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="TencentARC/t2iadapter_canny_sd14v1",
- choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = (
- T2IAdapterBenchmark(args)
- if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
- else T2IAdapterSDXLBenchmark(args)
- )
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_t2i_lcm_lora.py b/benchmarks/benchmark_t2i_lcm_lora.py
deleted file mode 100644
index 957e0a463e..0000000000
--- a/benchmarks/benchmark_t2i_lcm_lora.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="stabilityai/stable-diffusion-xl-base-1.0",
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=4)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = LCMLoRATextToImageBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_text_to_image.py b/benchmarks/benchmark_text_to_image.py
deleted file mode 100644
index ddc7fb2676..0000000000
--- a/benchmarks/benchmark_text_to_image.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
-
-
-ALL_T2I_CKPTS = [
- "Lykon/DreamShaper",
- "segmind/SSD-1B",
- "stabilityai/stable-diffusion-xl-base-1.0",
- "kandinsky-community/kandinsky-2-2-decoder",
- "warp-ai/wuerstchen",
- "stabilityai/sdxl-turbo",
-]
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=ALL_T2I_CKPTS,
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_cls = None
- if "turbo" in args.ckpt:
- benchmark_cls = TurboTextToImageBenchmark
- else:
- benchmark_cls = TextToImageBenchmark
-
- benchmark_pipe = benchmark_cls(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py
new file mode 100644
index 0000000000..18a2680052
--- /dev/null
+++ b/benchmarks/benchmarking_flux.py
@@ -0,0 +1,98 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+RESULT_FILENAME = "flux.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # resolution: 1024x1024
+ # maximum sequence length 512
+ hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+ pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
+    image_ids = torch.ones(4096, 3, **device_dtype_kwargs)
+    text_ids = torch.ones(512, 3, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ guidance = torch.tensor([1.0], **device_dtype_kwargs)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": timestep,
+ "guidance": guidance,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bnb-nf4",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ "quantization_config": BitsAndBytesConfig(
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
+ ),
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_ltx.py b/benchmarks/benchmarking_ltx.py
new file mode 100644
index 0000000000..3d698fd0bd
--- /dev/null
+++ b/benchmarks/benchmarking_ltx.py
@@ -0,0 +1,80 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import LTXVideoTransformer3DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
+RESULT_FILENAME = "ltx.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # 512x704 (161 frames)
+ # `max_sequence_length`: 256
+ hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
+ encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "timestep": timestep,
+ "video_coords": video_coords,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_sdxl.py b/benchmarks/benchmarking_sdxl.py
new file mode 100644
index 0000000000..ded62784f2
--- /dev/null
+++ b/benchmarks/benchmarking_sdxl.py
@@ -0,0 +1,82 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import UNet2DConditionModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+RESULT_FILENAME = "sdxl.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # height: 1024
+ # width: 1024
+ # max_sequence_length: 77
+ hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ added_cond_kwargs = {
+ "text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
+ "time_ids": torch.ones(1, 6, **device_dtype_kwargs),
+ }
+
+ return {
+ "sample": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "added_cond_kwargs": added_cond_kwargs,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py
new file mode 100644
index 0000000000..c8c1a10ef8
--- /dev/null
+++ b/benchmarks/benchmarking_utils.py
@@ -0,0 +1,244 @@
+import gc
+import inspect
+import logging
+import os
+import queue
+import threading
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Union
+
+import pandas as pd
+import torch
+import torch.utils.benchmark as benchmark
+
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+NUM_WARMUP_ROUNDS = 5
+
+
+def benchmark_fn(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)",
+ globals={"args": args, "kwargs": kwargs, "f": f},
+ num_threads=1,
+ )
+ return float(f"{(t0.blocked_autorange().mean):.3f}")
+
+
+def flush():
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
+
+
+# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
+def calculate_flops(model, input_dict):
+    try:
+        from torchprofile import profile_macs
+    except ModuleNotFoundError as e:
+        raise ModuleNotFoundError(
+            "Please install `torchprofile` (`pip install torchprofile`) to calculate FLOPs."
+        ) from e
+
+ # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
+ sig = inspect.signature(model.forward)
+ param_names = [
+ p.name
+ for p in sig.parameters.values()
+ if p.kind
+ in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ and p.name != "self"
+ ]
+ bound = sig.bind_partial(**input_dict)
+ bound.apply_defaults()
+ args = tuple(bound.arguments[name] for name in param_names)
+
+ model.eval()
+ with torch.no_grad():
+ macs = profile_macs(model, args)
+ flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
+ return flops
+
+
+def calculate_params(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+# Users can define their own in case this doesn't suffice. For most cases,
+# it should be sufficient.
+def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
+ model = model_cls.from_pretrained(**init_kwargs).eval()
+ if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
+ model.enable_group_offload(**group_offload_kwargs)
+ else:
+ model.to(torch_device)
+ if layerwise_upcasting:
+ model.enable_layerwise_casting(
+ storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
+ )
+ return model
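+
+
+# For illustration only -- a user-supplied init function could look like the
+# following (hypothetical `my_init_fn`, not used anywhere in this folder):
+#
+# def my_init_fn(model_cls, **init_kwargs):
+#     model = model_cls.from_pretrained(**init_kwargs).eval()
+#     return model.to(torch_device)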
+
+
+@dataclass
+class BenchmarkScenario:
+ name: str
+ model_cls: ModelMixin
+ model_init_kwargs: Dict[str, Any]
+ model_init_fn: Callable
+ get_model_input_dict: Callable
+ compile_kwargs: Optional[Dict[str, Any]] = None
+
+
+@require_torch_gpu
+class BenchmarkMixin:
+ def pre_benchmark(self):
+ flush()
+ torch.compiler.reset()
+
+ def post_benchmark(self, model):
+ model.cpu()
+ flush()
+ torch.compiler.reset()
+
+ @torch.no_grad()
+ def run_benchmark(self, scenario: BenchmarkScenario):
+ # 0) Basic stats
+ logger.info(f"Running scenario: {scenario.name}.")
+ try:
+ model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
+ num_params = round(calculate_params(model) / 1e9, 2)
+ try:
+ flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
+ except Exception as e:
+ logger.info(f"Problem in calculating FLOPs:\n{e}")
+ flops = None
+ model.cpu()
+ del model
+ except Exception as e:
+ logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
+ return {}
+ self.pre_benchmark()
+
+ # 1) plain stats
+ results = {}
+ plain = None
+ try:
+ plain = self._run_phase(
+ model_cls=scenario.model_cls,
+ init_fn=scenario.model_init_fn,
+ init_kwargs=scenario.model_init_kwargs,
+ get_input_fn=scenario.get_model_input_dict,
+ compile_kwargs=None,
+ )
+ except Exception as e:
+ logger.info(f"Benchmark could not be run with the following error:\n{e}")
+ return results
+
+ # 2) compiled stats (if any)
+ compiled = {"time": None, "memory": None}
+ if scenario.compile_kwargs:
+ try:
+ compiled = self._run_phase(
+ model_cls=scenario.model_cls,
+ init_fn=scenario.model_init_fn,
+ init_kwargs=scenario.model_init_kwargs,
+ get_input_fn=scenario.get_model_input_dict,
+ compile_kwargs=scenario.compile_kwargs,
+ )
+ except Exception as e:
+                logger.info(f"Compilation benchmark could not be run with the following error:\n{e}")
+ if plain is None:
+ return results
+
+ # 3) merge
+ result = {
+ "scenario": scenario.name,
+ "model_cls": scenario.model_cls.__name__,
+ "num_params_B": num_params,
+ "flops_G": flops,
+ "time_plain_s": plain["time"],
+ "mem_plain_GB": plain["memory"],
+ "time_compile_s": compiled["time"],
+ "mem_compile_GB": compiled["memory"],
+ }
+ if scenario.compile_kwargs:
+ result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
+ result["mode"] = scenario.compile_kwargs.get("mode", "default")
+ else:
+ result["fullgraph"], result["mode"] = None, None
+ return result
+
+    def run_benchmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
+ if not isinstance(scenarios, list):
+ scenarios = [scenarios]
+ record_queue = queue.Queue()
+ stop_signal = object()
+
+ def _writer_thread():
+ while True:
+ item = record_queue.get()
+ if item is stop_signal:
+ break
+ df_row = pd.DataFrame([item])
+ write_header = not os.path.exists(filename)
+ df_row.to_csv(filename, mode="a", header=write_header, index=False)
+ record_queue.task_done()
+
+ record_queue.task_done()
+
+ writer = threading.Thread(target=_writer_thread, daemon=True)
+ writer.start()
+
+ for s in scenarios:
+ try:
+ record = self.run_benchmark(s)
+ if record:
+ record_queue.put(record)
+ else:
+ logger.info(f"Record empty from scenario: {s.name}.")
+ except Exception as e:
+ logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
+ record_queue.put(stop_signal)
+ logger.info(f"Results serialized to {filename=}.")
+
+ def _run_phase(
+ self,
+ *,
+ model_cls: ModelMixin,
+ init_fn: Callable,
+ init_kwargs: Dict[str, Any],
+ get_input_fn: Callable,
+ compile_kwargs: Optional[Dict[str, Any]],
+ ) -> Dict[str, float]:
+ # setup
+ self.pre_benchmark()
+
+ # init & (optional) compile
+ model = init_fn(model_cls, **init_kwargs)
+ if compile_kwargs:
+ model.compile(**compile_kwargs)
+
+ # build inputs
+ inp = get_input_fn()
+
+ # measure
+ run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
+ with run_ctx:
+ for _ in range(NUM_WARMUP_ROUNDS):
+ _ = model(**inp)
+ time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
+ mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
+ mem_gb = round(mem_gb, 2)
+
+ # teardown
+ self.post_benchmark(model)
+ del model
+ return {"time": time_s, "memory": mem_gb}
diff --git a/benchmarks/benchmarking_wan.py b/benchmarks/benchmarking_wan.py
new file mode 100644
index 0000000000..64e81fdb6b
--- /dev/null
+++ b/benchmarks/benchmarking_wan.py
@@ -0,0 +1,74 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import WanTransformer3DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+RESULT_FILENAME = "wan.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # height: 480
+ # width: 832
+ # num_frames: 81
+ # max_sequence_length: 512
+ hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+
+ return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+    runner.run_benchmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py
new file mode 100644
index 0000000000..55e46b0586
--- /dev/null
+++ b/benchmarks/populate_into_db.py
@@ -0,0 +1,166 @@
+import argparse
+import os
+import sys
+
+import gpustat
+import pandas as pd
+import psycopg2
+import psycopg2.extras
+from psycopg2.extensions import register_adapter
+from psycopg2.extras import Json
+
+
+register_adapter(dict, Json)
+
+FINAL_CSV_FILENAME = "collated_results.csv"
+# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
+BENCHMARKS_TABLE_NAME = "benchmarks"
+MEASUREMENTS_TABLE_NAME = "model_measurements"
+
+
+def _init_benchmark(conn, branch, commit_id, commit_msg):
+ gpu_stats = gpustat.GPUStatCollection.new_query()
+ metadata = {"gpu_name": gpu_stats[0]["name"]}
+ repository = "huggingface/diffusers"
+ with conn.cursor() as cur:
+ cur.execute(
+ f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
+ (repository, branch, commit_id, commit_msg, metadata),
+ )
+ benchmark_id = cur.fetchone()[0]
+ print(f"Initialised benchmark #{benchmark_id}")
+ return benchmark_id
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "branch",
+ type=str,
+ help="The branch name on which the benchmarking is performed.",
+ )
+
+ parser.add_argument(
+ "commit_id",
+ type=str,
+ help="The commit hash on which the benchmarking is performed.",
+ )
+
+ parser.add_argument(
+ "commit_msg",
+ type=str,
+ help="The commit message associated with the commit, truncated to 70 characters.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ try:
+ conn = psycopg2.connect(
+ host=os.getenv("PGHOST"),
+ database=os.getenv("PGDATABASE"),
+ user=os.getenv("PGUSER"),
+ password=os.getenv("PGPASSWORD"),
+ )
+ print("DB connection established successfully.")
+ except Exception as e:
+ print(f"Problem during DB init: {e}")
+ sys.exit(1)
+
+ try:
+ benchmark_id = _init_benchmark(
+ conn=conn,
+ branch=args.branch,
+ commit_id=args.commit_id,
+ commit_msg=args.commit_msg,
+ )
+ except Exception as e:
+ print(f"Problem during initializing benchmark: {e}")
+ sys.exit(1)
+
+ cur = conn.cursor()
+
+ df = pd.read_csv(FINAL_CSV_FILENAME)
+
+ # Helper to cast values (or None) given a dtype
+ def _cast_value(val, dtype: str):
+ if pd.isna(val):
+ return None
+
+ if dtype == "text":
+ return str(val).strip()
+
+ if dtype == "float":
+ try:
+ return float(val)
+ except ValueError:
+ return None
+
+ if dtype == "bool":
+ s = str(val).strip().lower()
+ if s in ("true", "t", "yes", "1"):
+ return True
+ if s in ("false", "f", "no", "0"):
+ return False
+ if val in (1, 1.0):
+ return True
+ if val in (0, 0.0):
+ return False
+ return None
+
+ return val
+
+ try:
+ rows_to_insert = []
+ for _, row in df.iterrows():
+ scenario = _cast_value(row.get("scenario"), "text")
+ model_cls = _cast_value(row.get("model_cls"), "text")
+ num_params_B = _cast_value(row.get("num_params_B"), "float")
+ flops_G = _cast_value(row.get("flops_G"), "float")
+ time_plain_s = _cast_value(row.get("time_plain_s"), "float")
+ mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
+ time_compile_s = _cast_value(row.get("time_compile_s"), "float")
+ mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
+ fullgraph = _cast_value(row.get("fullgraph"), "bool")
+ mode = _cast_value(row.get("mode"), "text")
+
+ # If "github_sha" column exists in the CSV, cast it; else default to None
+ if "github_sha" in df.columns:
+ github_sha = _cast_value(row.get("github_sha"), "text")
+ else:
+ github_sha = None
+
+ measurements = {
+ "scenario": scenario,
+ "model_cls": model_cls,
+ "num_params_B": num_params_B,
+ "flops_G": flops_G,
+ "time_plain_s": time_plain_s,
+ "mem_plain_GB": mem_plain_GB,
+ "time_compile_s": time_compile_s,
+ "mem_compile_GB": mem_compile_GB,
+ "fullgraph": fullgraph,
+ "mode": mode,
+ "github_sha": github_sha,
+ }
+ rows_to_insert.append((benchmark_id, measurements))
+
+ # Batch-insert all rows
+ insert_sql = f"""
+ INSERT INTO {MEASUREMENTS_TABLE_NAME} (
+ benchmark_id,
+ measurements
+ )
+ VALUES (%s, %s);
+ """
+
+ psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
+ conn.commit()
+
+ cur.close()
+ conn.close()
+ except Exception as e:
+ print(f"Exception: {e}")
+ sys.exit(1)
diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py
index 71cd60f32c..8be3b39368 100644
--- a/benchmarks/push_results.py
+++ b/benchmarks/push_results.py
@@ -1,19 +1,19 @@
-import glob
-import sys
+import os
import pandas as pd
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub.utils import EntryNotFoundError
-sys.path.append(".")
-from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
+REPO_ID = "diffusers/benchmarks"
def has_previous_benchmark() -> str:
+ from run_all import FINAL_CSV_FILENAME
+
csv_path = None
try:
- csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
+ csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
except EntryNotFoundError:
csv_path = None
return csv_path
@@ -26,46 +26,50 @@ def filter_float(value):
def push_to_hf_dataset():
- all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
- collate_csv(all_csvs, FINAL_CSV_FILE)
+ from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
- # If there's an existing benchmark file, we should report the changes.
csv_path = has_previous_benchmark()
if csv_path is not None:
- current_results = pd.read_csv(FINAL_CSV_FILE)
+ current_results = pd.read_csv(FINAL_CSV_FILENAME)
previous_results = pd.read_csv(csv_path)
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
- numeric_columns = [
- c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
- ]
for column in numeric_columns:
- previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
+ # get previous values as floats, aligned to current index
+ prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
- # Calculate the percentage change
- current_results[column] = current_results[column].astype(float)
- previous_results[column] = previous_results[column].astype(float)
- percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
+ # get current values as floats
+ curr_vals = current_results[column].astype(float)
- # Format the values with '+' or '-' sign and append to original values
- current_results[column] = current_results[column].map(str) + percent_change.map(
- lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
+ # stringify the current values
+ curr_str = curr_vals.map(str)
+
+ # build an appendage only when prev exists and differs
+ append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
+ lambda x: f" ({x})" if pd.notnull(x) else ""
)
- # There might be newly added rows. So, filter out the NaNs.
- current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
- # Overwrite the current result file.
- current_results.to_csv(FINAL_CSV_FILE, index=False)
+ # combine
+ current_results[column] = curr_str + append_str
+ os.remove(FINAL_CSV_FILENAME)
+ current_results.to_csv(FINAL_CSV_FILENAME, index=False)
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
upload_file(
repo_id=REPO_ID,
- path_in_repo=FINAL_CSV_FILE,
- path_or_fileobj=FINAL_CSV_FILE,
+ path_in_repo=FINAL_CSV_FILENAME,
+ path_or_fileobj=FINAL_CSV_FILENAME,
repo_type="dataset",
commit_message=commit_message,
)
+ upload_file(
+ repo_id="diffusers/benchmark-analyzer",
+ path_in_repo=FINAL_CSV_FILENAME,
+ path_or_fileobj=FINAL_CSV_FILENAME,
+ repo_type="space",
+ commit_message=commit_message,
+ )
if __name__ == "__main__":
diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt
new file mode 100644
index 0000000000..1f47ecc6ca
--- /dev/null
+++ b/benchmarks/requirements.txt
@@ -0,0 +1,6 @@
+pandas
+psutil
+gpustat
+torchprofile
+bitsandbytes
+psycopg2==2.9.9
\ No newline at end of file
diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py
index c9932cc71c..9cf053f548 100644
--- a/benchmarks/run_all.py
+++ b/benchmarks/run_all.py
@@ -1,101 +1,84 @@
import glob
+import logging
+import os
import subprocess
-import sys
-from typing import List
+
+import pandas as pd
-sys.path.append(".")
-from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
-
-PATTERN = "benchmark_*.py"
+PATTERN = "benchmarking_*.py"
+FINAL_CSV_FILENAME = "collated_results.csv"
+GITHUB_SHA = os.getenv("GITHUB_SHA", None)
class SubprocessCallException(Exception):
pass
-# Taken from `test_examples_utils.py`
-def run_command(command: List[str], return_stdout=False):
- """
- Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
- if an error occurred while running `command`
- """
+def run_command(command: list[str], return_stdout=False):
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
- if return_stdout:
- if hasattr(output, "decode"):
- output = output.decode("utf-8")
- return output
+ if return_stdout and hasattr(output, "decode"):
+ return output.decode("utf-8")
except subprocess.CalledProcessError as e:
- raise SubprocessCallException(
- f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
- ) from e
+ raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
-def main():
- python_files = glob.glob(PATTERN)
+def merge_csvs(final_csv: str = "collated_results.csv"):
+ all_csvs = glob.glob("*.csv")
+ all_csvs = [f for f in all_csvs if f != final_csv]
+ if not all_csvs:
+ logger.info("No result CSVs found to merge.")
+ return
- for file in python_files:
- print(f"****** Running file: {file} ******")
-
- # Run with canonical settings.
- if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
- command = f"python {file}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
-
- # Run variants.
- for file in python_files:
- # See: https://github.com/pytorch/pytorch/issues/129637
- if file == "benchmark_ip_adapters.py":
+ df_list = []
+ for f in all_csvs:
+ try:
+ d = pd.read_csv(f)
+ except pd.errors.EmptyDataError:
+            # If a file exists but is zero bytes or corrupted, skip it
continue
+ df_list.append(d)
- if file == "benchmark_text_to_image.py":
- for ckpt in ALL_T2I_CKPTS:
- command = f"python {file} --ckpt {ckpt}"
+ if not df_list:
+ logger.info("All result CSVs were empty or invalid; nothing to merge.")
+ return
- if "turbo" in ckpt:
- command += " --num_inference_steps 1"
+ final_df = pd.concat(df_list, ignore_index=True)
+ if GITHUB_SHA is not None:
+ final_df["github_sha"] = GITHUB_SHA
+ final_df.to_csv(final_csv, index=False)
+ logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
- run_command(command.split())
- command += " --run_compile"
- run_command(command.split())
+def run_scripts():
+ python_files = sorted(glob.glob(PATTERN))
+ python_files = [f for f in python_files if f != "benchmarking_utils.py"]
- elif file == "benchmark_sd_img.py":
- for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
- command = f"python {file} --ckpt {ckpt}"
+ for file in python_files:
+ script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
+ logger.info(f"\n****** Running file: {file} ******")
- if ckpt == "stabilityai/sdxl-turbo":
- command += " --num_inference_steps 2"
+ partial_csv = f"{script_name}.csv"
+ if os.path.exists(partial_csv):
+            logger.info(f"Found existing {partial_csv}. Removing it to avoid stale results and duplicated rows.")
+ os.remove(partial_csv)
- run_command(command.split())
- command += " --run_compile"
- run_command(command.split())
+ command = ["python", file]
+ try:
+ run_command(command)
+ logger.info(f"→ {file} finished normally.")
+ except SubprocessCallException as e:
+ logger.info(f"Error running {file}:\n{e}")
+ finally:
+ logger.info(f"→ Merging partial CSVs after {file} …")
+ merge_csvs(final_csv=FINAL_CSV_FILENAME)
- elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
- sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
- command = f"python {file} --ckpt {sdxl_ckpt}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
-
- elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
- sdxl_ckpt = (
- "diffusers/controlnet-canny-sdxl-1.0"
- if "controlnet" in file
- else "TencentARC/t2i-adapter-canny-sdxl-1.0"
- )
- command = f"python {file} --ckpt {sdxl_ckpt}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
+ logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
if __name__ == "__main__":
- main()
+ run_scripts()
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
deleted file mode 100644
index 5fce920ac6..0000000000
--- a/benchmarks/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import argparse
-import csv
-import gc
-import os
-from dataclasses import dataclass
-from typing import Dict, List, Union
-
-import torch
-import torch.utils.benchmark as benchmark
-
-
-GITHUB_SHA = os.getenv("GITHUB_SHA", None)
-BENCHMARK_FIELDS = [
- "pipeline_cls",
- "ckpt_id",
- "batch_size",
- "num_inference_steps",
- "model_cpu_offload",
- "run_compile",
- "time (secs)",
- "memory (gbs)",
- "actual_gpu_memory (gbs)",
- "github_sha",
-]
-
-PROMPT = "ghibli style, a fantasy landscape with castles"
-BASE_PATH = os.getenv("BASE_PATH", ".")
-TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
-
-REPO_ID = "diffusers/benchmarks"
-FINAL_CSV_FILE = "collated_results.csv"
-
-
-@dataclass
-class BenchmarkInfo:
- time: float
- memory: float
-
-
-def flush():
- """Wipes off memory."""
- gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
-
-
-def bytes_to_giga_bytes(bytes):
- return f"{(bytes / 1024 / 1024 / 1024):.3f}"
-
-
-def benchmark_fn(f, *args, **kwargs):
- t0 = benchmark.Timer(
- stmt="f(*args, **kwargs)",
- globals={"args": args, "kwargs": kwargs, "f": f},
- num_threads=torch.get_num_threads(),
- )
- return f"{(t0.blocked_autorange().mean):.3f}"
-
-
-def generate_csv_dict(
- pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
-) -> Dict[str, Union[str, bool, float]]:
- """Packs benchmarking data into a dictionary for latter serialization."""
- data_dict = {
- "pipeline_cls": pipeline_cls,
- "ckpt_id": ckpt,
- "batch_size": args.batch_size,
- "num_inference_steps": args.num_inference_steps,
- "model_cpu_offload": args.model_cpu_offload,
- "run_compile": args.run_compile,
- "time (secs)": benchmark_info.time,
- "memory (gbs)": benchmark_info.memory,
- "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
- "github_sha": GITHUB_SHA,
- }
- return data_dict
-
-
-def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
- """Serializes a dictionary into a CSV file."""
- with open(file_name, mode="w", newline="") as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
- writer.writeheader()
- writer.writerow(data_dict)
-
-
-def collate_csv(input_files: List[str], output_file: str):
- """Collates multiple identically structured CSVs into a single CSV file."""
- with open(output_file, mode="w", newline="") as outfile:
- writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
- writer.writeheader()
-
- for file in input_files:
- with open(file, mode="r") as infile:
- reader = csv.DictReader(infile)
- for row in reader:
- writer.writerow(row)
diff --git a/docker/diffusers-doc-builder/Dockerfile b/docker/diffusers-doc-builder/Dockerfile
index c9fc62707c..3a76b3331c 100644
--- a/docker/diffusers-doc-builder/Dockerfile
+++ b/docker/diffusers-doc-builder/Dockerfile
@@ -47,6 +47,10 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
tensorboard \
transformers \
matplotlib \
- setuptools==69.5.1
+ setuptools==69.5.1 \
+ bitsandbytes \
+ torchao \
+ gguf \
+ optimum-quanto
CMD ["/bin/bash"]
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 283efeef72..770093438e 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -64,6 +64,8 @@
title: Overview
- local: using-diffusers/create_a_server
title: Create a server
+ - local: using-diffusers/batched_inference
+ title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- local: using-diffusers/scheduler_features
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
index e90cb32c54..9ba4742085 100644
--- a/docs/source/en/api/cache.md
+++ b/docs/source/en/api/cache.md
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
+
+### FirstBlockCacheConfig
+
+[[autodoc]] FirstBlockCacheConfig
+
+[[autodoc]] apply_first_block_cache
diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md
index eb78c8b704..ad292abca2 100644
--- a/docs/source/en/api/pipelines/amused.md
+++ b/docs/source/en/api/pipelines/amused.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# aMUSEd
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
diff --git a/docs/source/en/api/pipelines/attend_and_excite.md b/docs/source/en/api/pipelines/attend_and_excite.md
index ca0aa7af98..b5ce3bb767 100644
--- a/docs/source/en/api/pipelines/attend_and_excite.md
+++ b/docs/source/en/api/pipelines/attend_and_excite.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Attend-and-Excite
Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md
index a5ef9c4872..6b143d2990 100644
--- a/docs/source/en/api/pipelines/audioldm.md
+++ b/docs/source/en/api/pipelines/audioldm.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# AudioLDM
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md
index c13288d489..d94281a4a9 100644
--- a/docs/source/en/api/pipelines/blip_diffusion.md
+++ b/docs/source/en/api/pipelines/blip_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# BLIP-Diffusion
BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 4e2d144421..40e290e4bd 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
-pipe.enabe_model_cpu_offload()
+pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md
index 2eebcc6b74..aea8cb2e86 100644
--- a/docs/source/en/api/pipelines/controlnetxs.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
index 0862a5d798..76937b16c5 100644
--- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md
index 99deef37e1..dba807c5ce 100644
--- a/docs/source/en/api/pipelines/cosmos.md
+++ b/docs/source/en/api/pipelines/cosmos.md
@@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
+## Loading original format checkpoints
+
+Original format checkpoints that haven't been converted to the Diffusers format can be loaded with the `from_single_file` method.
+
+```python
+import torch
+from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
+
+model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
+transformer = CosmosTransformer3DModel.from_single_file(
+ "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
+negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+
+output = pipe(
+ prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
+).images[0]
+output.save("output.png")
+```
+
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md
index 64a738f17c..5805561e49 100644
--- a/docs/source/en/api/pipelines/dance_diffusion.md
+++ b/docs/source/en/api/pipelines/dance_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Dance Diffusion
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md
index 02a76cf589..9734ca2eab 100644
--- a/docs/source/en/api/pipelines/diffedit.md
+++ b/docs/source/en/api/pipelines/diffedit.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# DiffEdit
[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md
index eea7eeab19..76a51a6cd5 100644
--- a/docs/source/en/api/pipelines/i2vgenxl.md
+++ b/docs/source/en/api/pipelines/i2vgenxl.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# I2VGen-XL
[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.
diff --git a/docs/source/en/api/pipelines/musicldm.md b/docs/source/en/api/pipelines/musicldm.md
index 5072bcc4fb..c2297162f7 100644
--- a/docs/source/en/api/pipelines/musicldm.md
+++ b/docs/source/en/api/pipelines/musicldm.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MusicLDM
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
diff --git a/docs/source/en/api/pipelines/paint_by_example.md b/docs/source/en/api/pipelines/paint_by_example.md
index 769156643b..362c26de68 100644
--- a/docs/source/en/api/pipelines/paint_by_example.md
+++ b/docs/source/en/api/pipelines/paint_by_example.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Paint by Example
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md
index a9a95759d6..9f61388dd5 100644
--- a/docs/source/en/api/pipelines/panorama.md
+++ b/docs/source/en/api/pipelines/panorama.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MultiDiffusion
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md
index a58d7fbe8d..7bd480b49a 100644
--- a/docs/source/en/api/pipelines/pia.md
+++ b/docs/source/en/api/pipelines/pia.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Image-to-Video Generation with PIA (Personalized Image Animator)
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md
index f86cbc0b6f..5578fdfa63 100644
--- a/docs/source/en/api/pipelines/self_attention_guidance.md
+++ b/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Self-Attention Guidance
[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
index 99395e75a9..1ce44cf2de 100644
--- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md
+++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Semantic Guidance
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/gligen.md b/docs/source/en/api/pipelines/stable_diffusion/gligen.md
index 73be0b4ca8..e9704fc1de 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/gligen.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/gligen.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# GLIGEN (Grounded Language-to-Image Generation)
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
index 4d7fda2a0c..75f052b08f 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# K-Diffusion
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
index 9f54538968..4c52ed90f0 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text-to-(RGB, depth)
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
index ac5b97b672..1736491107 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
diff --git a/docs/source/en/api/pipelines/text_to_video.md b/docs/source/en/api/pipelines/text_to_video.md
index 116aea736f..7faf88d133 100644
--- a/docs/source/en/api/pipelines/text_to_video.md
+++ b/docs/source/en/api/pipelines/text_to_video.md
@@ -10,11 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-
-
-🧪 This pipeline is for research purposes only.
-
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text-to-video
diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md
index 7966f43390..5fe3789d82 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.md
+++ b/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text2Video-Zero
diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index c9a3164226..8011a4b533 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index bce55b67ed..7d767f2db5 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# UniDiffuser
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 18b8207e3b..81cd242151 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
- from diffusers import WanPipeline, AutoModel
+ from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
- vae = AutoModel.from_single_file(
+ vae = AutoencoderKLWan.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
- transformer = AutoModel.from_single_file(
+ transformer = WanTransformer3DModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index 561df2017d..2be3631d84 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Würstchen
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720c..5a382c1c94 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
+If you expect varied resolutions during inference with this feature, make sure to set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
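+The snippet below is a minimal sketch of combining the two settings; the checkpoint and LoRA are only illustrative:
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
+
+# dynamic=True avoids recompilation when the inference resolution changes
+pipeline.transformer = torch.compile(pipeline.transformer, dynamic=True)
+
+# raise an error instead of silently recompiling so recompilations are easy to spot
+with torch._dynamo.config.patch(error_on_recompile=True):
+    image = pipeline("yarn art of a dog", num_inference_steps=28).images[0]
+```
+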
There are still scenarios where recompilation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
## Merge
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 0000000000..b5e55c27ca
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,264 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. Running several prompts in a single forward pass makes better use of the GPU than generating one prompt at a time and leaving the hardware underutilized.
+
+The downside is increased latency, because you must wait for the entire batch to complete, and larger batches require more GPU memory.
+
+
+
+
+For text-to-image, pass a list of prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+ prompt=prompts,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+For image-to-image, pass a list of input images and prompts to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+# the image-to-image pipeline accepts the `image` and `strength` arguments used below
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ guidance_scale=8.0,
+ strength=0.5
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import AutoPipelineForImage2Image
+
+pipeline = AutoPipelineForImage2Image.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ image=input_image,
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generators](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tying each `Generator` to a seed so you can reuse it.
+
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+
+Don't multiply a single `Generator` by the batch size as shown below, because that creates only *one* `Generator` object that is reused sequentially for every image in the batch.
+
+```py
+# avoid this: one Generator shared across the whole batch
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+```
+
+Instead, create one `Generator` per image and pass the list to the pipeline.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index df3df92f06..11afbf29d3 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
-#### LoRA files
+#### LoRAs
-[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
+[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA, which avoids missing or incorrect LoRA configurations.
-LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-# base model
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-# download LoRA weights
-!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
-
-# load LoRA weights
-pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
-prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
-negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
-
-image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.manual_seed(0),
-).images[0]
-image
-```
+The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
-

+
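+You can also inspect the metadata locally with the `safetensors` library; the sketch below is a minimal example and the file path is a placeholder:
+
+```py
+from safetensors import safe_open
+
+# prints the LoRA configuration stored in the file's metadata, if any
+with safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
+    print(f.metadata())
+```
+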
+For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`], as long as the weights are saved as a safetensors file.
+
+```py
+import torch
+from diffusers import FluxPipeline
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
+pipeline.save_lora_weights(
+    save_directory="yarn_art_flux_lora",  # save_directory is required; this path is illustrative
+    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
+    text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
+)
+```
+
### ckpt
> [!WARNING]
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 60b8fee754..ac9350f24c 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
print("L_inf dist =", abs(result1 - result2).max())
"L_inf dist = tensor(0., device='cuda:0')"
```
-
-## Deterministic batch generation
-
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
-
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
-```
-
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
-
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
-
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index a3efbf2e80..aabb9dd31c 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -242,3 +242,15 @@ unet = UNet2DConditionModel.from_pretrained(
)
unet.save_pretrained("./local-unet", variant="non_ema")
```
+
+Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
+
+```py
+import torch
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights, unlike the `torch_dtype` argument, which respects `_keep_in_fp32_modules`. This matters for models whose layers must remain in fp32 for numerical stability and the best generation quality (see the example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
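+
+A minimal sketch contrasting the two approaches, reusing the SDXL UNet from the example above:
+
+```py
+import torch
+from diffusers import AutoModel
+
+# torch_dtype keeps any modules listed in `_keep_in_fp32_modules` in fp32
+unet = AutoModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+
+# .to() converts *every* parameter to fp16, including modules meant to stay in fp32
+unet_all_fp16 = AutoModel.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
+).to(torch.float16)
+```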
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index 24c71d5c56..18273746c2 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
## Training Kontext
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
-provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
+provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply to this script, too.
-Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
+> [!IMPORTANT]
+> To make sure you can successfully run the latest version of the Kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To do this, execute the following steps in a new virtual environment:
+> ```
+> git clone https://github.com/huggingface/diffusers
+> cd diffusers
+> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
+> pip install -e .
+> ```
Below is an example training command:
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
perform as expected.
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+ --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+ --output_dir="kontext-i2i" \
+ --dataset_name="kontext-community/relighting" \
+ --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --optimizer="adamw" \
+ --use_8bit_adam \
+ --cache_latents \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=200 \
+ --max_train_steps=1000 \
+ --rank=16\
+ --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset structured like `kontext-community/relighting`, i.e. with a condition image, a target image, and an instruction per sample (see the quick check below)
+* Supply the `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
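+A minimal sketch of checking that a dataset exposes the expected columns (the column names below match the command above):
+
+```py
+from datasets import load_dataset
+
+ds = load_dataset("kontext-community/relighting", split="train")
+# expect the target image ("output"), condition image ("file_name"), and instruction columns
+print(ds.column_names)
+```
+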
### Misc notes
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
## Other notes
-Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
+Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
index 9f97567b06..5bd9b8684d 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
from torch.utils.data import Dataset
from torch.utils.data.sampler import BatchSampler
from torchvision import transforms
-from torchvision.transforms.functional import crop
+from torchvision.transforms import functional as TF
from tqdm.auto import tqdm
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
free_memory,
parse_buckets_string,
)
-from diffusers.utils import (
- check_min_version,
- convert_unet_state_dict_to_peft,
- is_wandb_available,
-)
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
@@ -186,6 +182,7 @@ def log_validation(
)
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
pipeline.set_progress_bar_config(disable=True)
+ pipeline_args_cp = pipeline_args.copy()
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
@@ -193,14 +190,16 @@ def log_validation(
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
with torch.no_grad():
- prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
- pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
- )
+ prompt = pipeline_args_cp.pop("prompt")
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
- prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
+ **pipeline_args_cp,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ generator=generator,
).images[0]
images.append(image)
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
"default, the standard Image Dataset maps out 'file_name' "
"to 'image'.",
)
+ parser.add_argument(
+ "--cond_image_column",
+ type=str,
+ default=None,
+ help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+ )
parser.add_argument(
"--caption_column",
type=str,
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
- required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
+ )
parser.add_argument(
"--num_validation_images",
type=int,
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--output_dir",
type=str,
- default="flux-dreambooth-lora",
+ default="flux-kontext-lora",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
+ if args.cond_image_column is not None:
+ raise ValueError("Prior preservation isn't supported with I2I training.")
else:
# logger is not available yet
if args.class_data_dir is not None:
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+ if args.cond_image_column is not None:
+ assert args.image_column is not None
+ assert args.caption_column is not None
+ assert args.dataset_name is not None
+ assert not args.train_text_encoder
+ if args.validation_prompt is not None:
+            assert args.validation_image is not None and os.path.exists(args.validation_image)
+
return args
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
repeats=1,
center_crop=False,
buckets=None,
+ args=None,
):
self.center_crop = center_crop
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
+ if args.cond_image_column is not None and args.cond_image_column not in column_names:
+ raise ValueError(
+ f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
if args.image_column is None:
image_column = column_names[0]
logger.info(f"image column defaulting to {image_column}")
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
raise ValueError(
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
- instance_images = dataset["train"][image_column]
+ instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
+ cond_images = None
+ cond_image_column = args.cond_image_column
+ if cond_image_column is not None:
+ cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+ assert len(instance_images) == len(cond_images)
if args.caption_column is None:
logger.info(
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
self.custom_instance_prompts = None
self.instance_images = []
- for img in instance_images:
+ self.cond_images = []
+ for i, img in enumerate(instance_images):
self.instance_images.extend(itertools.repeat(img, repeats))
+ if args.dataset_name is not None and cond_images is not None:
+ self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
self.pixel_values = []
- for image in self.instance_images:
+ self.cond_pixel_values = []
+ for i, image in enumerate(self.instance_images):
image = exif_transpose(image)
if not image.mode == "RGB":
image = image.convert("RGB")
+ dest_image = None
+ if self.cond_images:
+ dest_image = exif_transpose(self.cond_images[i])
+ if not dest_image.mode == "RGB":
+ dest_image = dest_image.convert("RGB")
width, height = image.size
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
self.size = (target_height, target_width)
# based on the bucket assignment, define the transformations
- train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
- train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
- train_flip = transforms.RandomHorizontalFlip(p=1.0)
- train_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
+ image, dest_image = self.paired_transform(
+ image,
+ dest_image=dest_image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
)
- image = train_resize(image)
- if args.center_crop:
- image = train_crop(image)
- else:
- y1, x1, h, w = train_crop.get_params(image, self.size)
- image = crop(image, y1, x1, h, w)
- if args.random_flip and random.random() < 0.5:
- image = train_flip(image)
- image = train_transforms(image)
self.pixel_values.append((image, bucket_idx))
+ if dest_image is not None:
+ self.cond_pixel_values.append((dest_image, bucket_idx))
self.num_instance_images = len(self.instance_images)
self._length = self.num_instance_images
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
example["instance_images"] = instance_image
example["bucket_idx"] = bucket_idx
+ if self.cond_pixel_values:
+ dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+ example["cond_images"] = dest_image
if self.custom_instance_prompts:
caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
return example
+ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
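+ # Apply identical resize/crop/flip parameters to the instance image and, when provided, its
+ # conditioning image so the pair stays spatially aligned.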
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+ if dest_image is not None:
+ dest_image = resize(dest_image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ if dest_image is not None:
+ dest_image = crop(dest_image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+ if dest_image is not None:
+ dest_image = TF.crop(dest_image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+ if dest_image is not None:
+ dest_image = TF.hflip(dest_image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+ if dest_image is not None:
+ dest_image = normalize(to_tensor(dest_image))
+
+ return image, dest_image
+
def collate_fn(examples, with_prior_preservation=False):
pixel_values = [example["instance_images"] for example in examples]
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
batch = {"pixel_values": pixel_values, "prompts": prompts}
+ if any("cond_images" in example for example in examples):
+ cond_pixel_values = [example["cond_images"] for example in examples]
+ cond_pixel_values = torch.stack(cond_pixel_values)
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+ batch.update({"cond_pixel_values": cond_pixel_values})
return batch
@@ -1318,6 +1393,7 @@ def main(args):
"ff.net.2",
"ff_context.net.0.proj",
"ff_context.net.2",
+ "proj_mlp",
]
# now we will add new LoRA weights the transformer layers
@@ -1534,7 +1610,10 @@ def main(args):
buckets=buckets,
repeats=args.repeats,
center_crop=args.center_crop,
+ args=args,
)
+ if args.cond_image_column is not None:
+ logger.info("I2I fine-tuning enabled.")
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
@@ -1574,6 +1653,7 @@ def main(args):
# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
free_memory()
@@ -1605,19 +1685,41 @@ def main(args):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+ elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
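+ # Pre-compute the prompt embeddings for every batch once up front so the text encoders can be
+ # offloaded and freed before training starts.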
+ cached_text_embeddings = []
+ for batch in tqdm(train_dataloader, desc="Embedding prompts"):
+ batch_prompts = batch["prompts"]
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ batch_prompts, text_encoders, tokenizers
+ )
+ cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
+
+ if args.validation_prompt is None:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
+ del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ free_memory()
+
vae_config_shift_factor = vae.config.shift_factor
vae_config_scaling_factor = vae.config.scaling_factor
vae_config_block_out_channels = vae.config.block_out_channels
+ has_image_input = args.cond_image_column is not None
if args.cache_latents:
latents_cache = []
+ cond_latents_cache = []
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
accelerator.device, non_blocking=True, dtype=weight_dtype
)
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ if has_image_input:
+ batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
if args.validation_prompt is None:
+ vae.cpu()
del vae
free_memory()
@@ -1678,7 +1780,7 @@ def main(args):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
- tracker_name = "dreambooth-flux-dev-lora"
+ tracker_name = "dreambooth-flux-kontext-lora"
accelerator.init_trackers(tracker_name, config=vars(args))
# Train!
@@ -1742,6 +1844,7 @@ def main(args):
sigma = sigma.unsqueeze(-1)
return sigma
+ has_guidance = unwrap_model(transformer).config.guidance_embeds
for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
@@ -1759,9 +1862,7 @@ def main(args):
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if not args.train_text_encoder:
- prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
- prompts, text_encoders, tokenizers
- )
+ prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
else:
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
tokens_two = tokenize_prompt(
@@ -1794,16 +1895,29 @@ def main(args):
if args.cache_latents:
if args.vae_encode_mode == "sample":
model_input = latents_cache[step].sample()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].sample()
else:
model_input = latents_cache[step].mode()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].mode()
else:
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ if has_image_input:
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
if args.vae_encode_mode == "sample":
model_input = vae.encode(pixel_values).latent_dist.sample()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
else:
model_input = vae.encode(pixel_values).latent_dist.mode()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
+ if has_image_input:
+ cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ cond_model_input = cond_model_input.to(dtype=weight_dtype)
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
@@ -1814,6 +1928,17 @@ def main(args):
accelerator.device,
weight_dtype,
)
+ if has_image_input:
+ cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
+ cond_model_input.shape[0],
+ cond_model_input.shape[2] // 2,
+ cond_model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
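+ # Tag the conditioning latents with 1 in the first id coordinate so the transformer can tell them
+ # apart from the target latents, mirroring the inference-time handling of the reference image in
+ # FluxKontextPipeline.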
+ cond_latents_ids[..., 0] = 1
+ latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
+
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
@@ -1834,7 +1959,6 @@ def main(args):
# zt = (1 - texp) * x + texp * z1
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
-
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
noisy_model_input,
batch_size=model_input.shape[0],
@@ -1842,13 +1966,22 @@ def main(args):
height=model_input.shape[2],
width=model_input.shape[3],
)
+ orig_inp_shape = packed_noisy_model_input.shape
+ if has_image_input:
+ packed_cond_input = FluxKontextPipeline._pack_latents(
+ cond_model_input,
+ batch_size=cond_model_input.shape[0],
+ num_channels_latents=cond_model_input.shape[1],
+ height=cond_model_input.shape[2],
+ width=cond_model_input.shape[3],
+ )
+ packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
- # handle guidance
- if unwrap_model(transformer).config.guidance_embeds:
+ # handle guidance (the Kontext transformer is guidance-distilled, so `guidance_embeds` is normally True)
+ guidance = None
+ if has_guidance:
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
guidance = guidance.expand(model_input.shape[0])
- else:
- guidance = None
# Predict the noise residual
model_pred = transformer(
@@ -1862,6 +1995,8 @@ def main(args):
img_ids=latent_image_ids,
return_dict=False,
)[0]
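+ # Only the target-image tokens contribute to the loss; drop the predictions made for the appended
+ # conditioning tokens.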
+ if has_image_input:
+ model_pred = model_pred[:, : orig_inp_shape[1]]
model_pred = FluxKontextPipeline._unpack_latents(
model_pred,
height=model_input.shape[2] * vae_scale_factor,
@@ -1970,6 +2105,8 @@ def main(args):
torch_dtype=weight_dtype,
)
pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
@@ -2030,6 +2167,8 @@ def main(args):
images = []
if args.validation_prompt and args.num_validation_images > 0:
pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
images = log_validation(
pipeline=pipeline,
args=args,
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 3c8b75a088..53ee0f89e2 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -837,11 +837,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
- if args.train_norm_layers:
- for name, param in flux_transformer.named_parameters():
- if any(k in name for k in NORM_LAYER_PREFIXES):
- param.requires_grad = True
-
if args.lora_layers is not None:
if args.lora_layers != "all-linear":
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
)
flux_transformer.add_adapter(transformer_lora_config)
+ if args.train_norm_layers:
+ for name, param in flux_transformer.named_parameters():
+ if any(k in name for k in NORM_LAYER_PREFIXES):
+ param.requires_grad = True
+
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py
index 0c0426a1ef..6f6563ad64 100644
--- a/scripts/convert_cosmos_to_diffusers.py
+++ b/scripts/convert_cosmos_to_diffusers.py
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
- # "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
diff --git a/setup.py b/setup.py
index 1efc698bba..936e2624c8 100644
--- a/setup.py
+++ b/setup.py
@@ -110,7 +110,7 @@ _deps = [
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
- "k-diffusion>=0.0.12",
+ "k-diffusion==0.0.12",
"torchsde",
"note_seq",
"librosa",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index b3f5f6ec9d..713472b4a5 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -133,9 +133,11 @@ else:
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
+ "FirstBlockCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"apply_faster_cache",
+ "apply_first_block_cache",
"apply_pyramid_attention_broadcast",
]
)
@@ -381,6 +383,7 @@ else:
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
+ "FluxKontextInpaintPipeline",
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
@@ -750,9 +753,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
HookRegistry,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
+ apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
from .models import (
@@ -975,6 +980,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index 5e06496a2d..1b2159afa2 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -17,7 +17,7 @@ deps = {
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
- "k-diffusion": "k-diffusion>=0.0.12",
+ "k-diffusion": "k-diffusion==0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 764ceb25b4..365bed3718 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -1,8 +1,23 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
+ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 0000000000..3be77dd4ce
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,30 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+ {
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ }
+)
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
new file mode 100644
index 0000000000..960d14e6fa
--- /dev/null
+++ b/src/diffusers/hooks/_helpers.py
@@ -0,0 +1,264 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Type
+
+
+@dataclass
+class AttentionProcessorMetadata:
+ skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class TransformerBlockMetadata:
+ return_hidden_states_index: Optional[int] = None
+ return_encoder_hidden_states_index: Optional[int] = None
+
+ _cls: Optional[Type] = None
+ _cached_parameter_indices: Optional[Dict[str, int]] = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if identifier in kwargs:
+ return kwargs[identifier]
+ if self._cached_parameter_indices is not None:
+ return args[self._cached_parameter_indices[identifier]]
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+ if identifier not in self._cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+ index = self._cached_parameter_indices[identifier]
+ if index >= len(args):
+ raise ValueError(f"Expected at least {index + 1} positional arguments but got {len(args)}.")
+ return args[index]
+
+
+class AttentionProcessorRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+ cls._register()
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_attention_processors_metadata()
+
+
+class TransformerBlockRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+ cls._register()
+ metadata._cls = model_class
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> TransformerBlockMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_transformer_blocks_metadata()
+
+
+def _register_attention_processors_metadata():
+ from ..models.attention_processor import AttnProcessor2_0
+ from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+
+ # AttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=AttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
+ ),
+ )
+
+ # CogView4AttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=CogView4AttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+ ),
+ )
+
+
+def _register_transformer_blocks_metadata():
+ from ..models.attention import BasicTransformerBlock
+ from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+ from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+ from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+ from ..models.transformers.transformer_hunyuan_video import (
+ HunyuanVideoSingleTransformerBlock,
+ HunyuanVideoTokenReplaceSingleTransformerBlock,
+ HunyuanVideoTokenReplaceTransformerBlock,
+ HunyuanVideoTransformerBlock,
+ )
+ from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+ from ..models.transformers.transformer_mochi import MochiTransformerBlock
+ from ..models.transformers.transformer_wan import WanTransformerBlock
+
+ # BasicTransformerBlock
+ TransformerBlockRegistry.register(
+ model_class=BasicTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # CogVideoX
+ TransformerBlockRegistry.register(
+ model_class=CogVideoXBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # CogView4
+ TransformerBlockRegistry.register(
+ model_class=CogView4TransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Flux
+ TransformerBlockRegistry.register(
+ model_class=FluxTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=FluxSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+
+ # HunyuanVideo
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # LTXVideo
+ TransformerBlockRegistry.register(
+ model_class=LTXVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # Mochi
+ TransformerBlockRegistry.register(
+ model_class=MochiTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Wan
+ TransformerBlockRegistry.register(
+ model_class=WanTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+
+# fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ return hidden_states
+
+
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ if encoder_hidden_states is None and len(args) > 1:
+ encoder_hidden_states = args[1]
+ return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+# fmt: on
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
new file mode 100644
index 0000000000..40ae8c5a26
--- /dev/null
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -0,0 +1,227 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
+from ._helpers import TransformerBlockRegistry
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
+_FBC_BLOCK_HOOK = "fbc_block_hook"
+
+
+@dataclass
+class FirstBlockCacheConfig:
+ r"""
+ Configuration for [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
+
+ Args:
+ threshold (`float`, defaults to `0.05`):
+ The threshold to determine whether or not a forward pass through all layers of the model is required. A
+ higher threshold usually results in a forward pass through a lower number of layers and faster inference,
+ but might lead to poorer generation quality. A lower threshold may not result in significant generation
+ speedup. The threshold is compared against the relative mean absolute change of the first transformer
+ block's residual between the current and the previous inference step. If the relative change is below the
+ threshold, the remaining blocks are skipped and previously cached residuals are reused.
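+
+ Example (a minimal usage sketch; the checkpoint, prompt, and threshold below are illustrative):
+
+ ```python
+ >>> import torch
+ >>> from diffusers import FluxPipeline
+ >>> from diffusers.hooks import FirstBlockCacheConfig, apply_first_block_cache
+
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+ >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.05))
+ >>> image = pipe("A photo of a cat", num_inference_steps=28).images[0]
+ ```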
+ """
+
+ threshold: float = 0.05
+
+
+class FBCSharedBlockState(BaseState):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.head_block_residual: torch.Tensor = None
+ self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.should_compute: bool = True
+
+ def reset(self):
+ self.tail_block_residuals = None
+ self.should_compute = True
+
+
+class FBCHeadBlockHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(self, state_manager: StateManager, threshold: float):
+ self.state_manager = state_manager
+ self.threshold = threshold
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ is_output_tuple = isinstance(output, tuple)
+
+ if is_output_tuple:
+ hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
+ else:
+ hidden_states_residual = output - original_hidden_states
+
+ shared_state: FBCSharedBlockState = self.state_manager.get_state()
+ hidden_states = encoder_hidden_states = None
+ should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
+ shared_state.should_compute = should_compute
+
+ if not should_compute:
+ # Apply caching
+ if is_output_tuple:
+ hidden_states = (
+ shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+ )
+ else:
+ hidden_states = shared_state.tail_block_residuals[0] + output
+
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ assert is_output_tuple
+ encoder_hidden_states = (
+ shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
+ )
+
+ if is_output_tuple:
+ return_output = [None] * len(output)
+ return_output[self._metadata.return_hidden_states_index] = hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
+ return_output = tuple(return_output)
+ else:
+ return_output = hidden_states
+ output = return_output
+ else:
+ if is_output_tuple:
+ head_block_output = [None] * len(output)
+ head_block_output[0] = output[self._metadata.return_hidden_states_index]
+ head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
+ else:
+ head_block_output = output
+ shared_state.head_block_output = head_block_output
+ shared_state.head_block_residual = hidden_states_residual
+
+ return output
+
+ def reset_state(self, module):
+ self.state_manager.reset()
+ return module
+
+ @torch.compiler.disable
+ def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
+ shared_state = self.state_manager.get_state()
+ if shared_state.head_block_residual is None:
+ return True
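+ # Compare the relative mean absolute change of the first block's residual against the threshold;
+ # a small change suggests the remaining blocks would produce nearly the same output as before.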
+ prev_hidden_states_residual = shared_state.head_block_residual
+ absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
+ prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
+ diff = (absmean / prev_hidden_states_absmean).item()
+ return diff > self.threshold
+
+
+class FBCBlockHook(ModelHook):
+ def __init__(self, state_manager: StateManager, is_tail: bool = False):
+ super().__init__()
+ self.state_manager = state_manager
+ self.is_tail = is_tail
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ original_encoder_hidden_states = None
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+
+ shared_state = self.state_manager.get_state()
+
+ if shared_state.should_compute:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ if self.is_tail:
+ hidden_states_residual = encoder_hidden_states_residual = None
+ if isinstance(output, tuple):
+ hidden_states_residual = (
+ output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
+ )
+ encoder_hidden_states_residual = (
+ output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
+ )
+ else:
+ hidden_states_residual = output - shared_state.head_block_output
+ shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+ return output
+
+ if original_encoder_hidden_states is None:
+ return_output = original_hidden_states
+ else:
+ return_output = [None, None]
+ return_output[self._metadata.return_hidden_states_index] = original_hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+ return_output = tuple(return_output)
+ return return_output
+
+
+def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+ state_manager = StateManager(FBCSharedBlockState, (), {})
+ remaining_blocks = []
+
+ for name, submodule in module.named_children():
+ if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
+ continue
+ for index, block in enumerate(submodule):
+ remaining_blocks.append((f"{name}.{index}", block))
+
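+ # The first transformer block always runs: FBCHeadBlockHook measures how much its residual changed
+ # and decides whether the remaining blocks can be skipped. All other blocks get FBCBlockHook, and the
+ # tail block additionally records the residuals that are re-applied whenever a forward pass is skipped.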
+ head_block_name, head_block = remaining_blocks.pop(0)
+ tail_block_name, tail_block = remaining_blocks.pop(-1)
+
+ logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
+ _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
+
+ for name, block in remaining_blocks:
+ logger.debug(f"Applying FBCBlockHook to '{name}'")
+ _apply_fbc_block_hook(block, state_manager)
+
+ logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
+ _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
+
+
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCHeadBlockHook(state_manager, threshold)
+ registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
+
+
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCBlockHook(state_manager, is_tail)
+ registry.register_hook(hook, _FBC_BLOCK_HOOK)
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 96231aadc3..6e097e5882 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
+class BaseState:
+ def reset(self, *args, **kwargs) -> None:
+ raise NotImplementedError(
+ "BaseState::reset is not implemented. Please implement this method in the derived class."
+ )
+
+
+class StateManager:
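+ # Lazily creates one `state_cls` instance per context name (set via `set_context`), so a single
+ # stateful hook can keep separate state for each cache context entered through `CacheMixin.cache_context`.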
+ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+ self._state_cls = state_cls
+ self._init_args = init_args if init_args is not None else ()
+ self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+ self._state_cache = {}
+ self._current_context = None
+
+ def get_state(self):
+ if self._current_context is None:
+ raise ValueError("No context is set. Please set a context before retrieving the state.")
+ if self._current_context not in self._state_cache.keys():
+ self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+ return self._state_cache[self._current_context]
+
+ def set_context(self, name: str) -> None:
+ self._current_context = name
+
+ def reset(self, *args, **kwargs) -> None:
+ for name, state in list(self._state_cache.items()):
+ state.reset(*args, **kwargs)
+ self._state_cache.pop(name)
+ self._current_context = None
+
+
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
+ def _set_context(self, module: torch.nn.Module, name: str) -> None:
+ # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+ for attr_name in dir(self):
+ attr = getattr(self, attr_name)
+ if isinstance(attr, StateManager):
+ attr.set_context(name)
+ return module
+
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
- for module_name, module in self._module_ref.named_modules():
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
+ module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
+ def _set_context(self, name: Optional[str] = None) -> None:
+ for hook_name in reversed(self._hook_order):
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook._set_context(self._module_ref, name)
+
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
+ if module_name == "":
+ continue
+ module = unwrap_module(module)
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook._set_context(name)
+
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 562a21dbbb..cd4738cfa0 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -934,6 +934,27 @@ class LoraBaseMixin:
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
+ After offloading the LoRA adapters to the CPU, they can no longer be used for inference while the rest of the
+ model stays on the GPU, as that would cause a device mismatch. Remember to move the adapters back to the GPU
+ before using them for inference again.
+
+ ```python
+ >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+ >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+ >>> pipe.set_adapters("adapter-1")
+ >>> image_1 = pipe(**kwargs)
+ >>> # switch to adapter-2, offload adapter-1
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-2")
+ >>> image_2 = pipe(**kwargs)
+ >>> # switch back to adapter-1, offload adapter-2
+ >>> pipe.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+ >>> pipe.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-1")
+ >>> ...
+ ```
+
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
@@ -949,6 +970,10 @@ class LoraBaseMixin:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
+ if adapter_name not in module.lora_A:
+ # it is sufficient to check lora_A
+ continue
+
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 25e06c007f..df3aa6212f 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+ original_state_dict_keys = list(original_state_dict.keys())
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ original_block_prefix = "base_model.model."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norms
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+ )
+
+ # Q, K, V
+ if lora_key == "lora_A":
+ sample_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+ context_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+ )
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+ if lora_key == "lora_A":
+ lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ lora_bias = original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+ else:
+ q, k, v, mlp = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+ )
+
+ for lora_key in ["lora_A", "lora_B"]:
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
@@ -1603,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+ has_time_projection_weight = any(
+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+ )
- diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
- if diff_keys:
- for diff_k in diff_keys:
- param = original_state_dict[diff_k]
- # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
- # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
- # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
- # is okay to ignore because they do not affect the model output in a significant manner.
- threshold = 1.6e-2
- absdiff = param.abs().max() - param.abs().min()
- all_zero = torch.all(param == 0).item()
- all_absdiff_lower_than_threshold = absdiff < threshold
- if all_zero or all_absdiff_lower_than_threshold:
- logger.debug(
- f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
- )
- original_state_dict.pop(diff_k)
+ for key in list(original_state_dict.keys()):
+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
+ # NOTE: norm-layer diff keys are not supported because, in the checkpoints seen so far, they only
+ # contain zeroed values. Support can be added later if checkpoints with non-zero values appear.
+ original_state_dict.pop(key)
+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+ if "time_projection" in key and not has_time_projection_weight:
+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
+ # CausVid lora has the weight keys and the bias keys.
+ original_state_dict.pop(key)
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4fea005cbc..4ee4808d80 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+ _convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return_metadata=return_lora_metadata,
)
+ is_fal_kontext = any("base_model" in k for k in state_dict)
+ if is_fal_kontext:
+ state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
+
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3670243de8..4ade3374d8 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -244,13 +244,20 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
- # create LoraConfig
- lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
+ # create LoraConfig
+ lora_config = _create_lora_config(
+ state_dict,
+ network_alphas,
+ metadata,
+ rank,
+ model_state_dict=self.state_dict(),
+ adapter_name=adapter_name,
+ )
+
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
def disable_cache(self) -> None:
- from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+ from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+ from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
- if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
- registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
- elif isinstance(self._cache_config, FasterCacheConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, FirstBlockCacheConfig):
+ registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+ registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+ @contextmanager
+ def cache_context(self, name: str):
+ r"""Context manager that activates a named cache context for the duration of the block, so stateful cache hooks keep separate state per context; the context is cleared on exit."""
+ from ..hooks import HookRegistry
+
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry._set_context(name)
+
+ yield
+
+ registry._set_context(None)
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index d8e99ee45e..063ff5bd8e 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+ single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index e4144d0c8e..dc45befb98 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
+@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
self,
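
For reference, `maybe_allow_in_graph` marks the block as safe to trace so `torch.compile` does not graph-break at block boundaries. A hedged sketch of how that is typically exercised (model id and compile options are assumptions):

```py
# Hedged sketch: compiling the CogView4 transformer now that its blocks carry
# @maybe_allow_in_graph.
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer, fullgraph=True)
image = pipe(prompt="a lighthouse at dusk, watercolor").images[0]
```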
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
index 6c312b7a5a..3a6cb1ce6e 100644
--- a/src/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -377,7 +378,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
return (emb / norm).type_as(hidden_states)
-class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
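
With `FromOriginalModelMixin`, the Cosmos transformer gains `from_single_file`. A hedged sketch, assuming a locally downloaded original-format checkpoint (the path is a placeholder):

```py
# Hedged sketch: loading an original-format Cosmos checkpoint via from_single_file.
import torch
from diffusers import CosmosTransformer3DModel

transformer = CosmosTransformer3DModel.from_single_file(
    "/path/to/original/cosmos_transformer.pt",  # hypothetical local checkpoint
    torch_dtype=torch.bfloat16,
)
```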
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 3af1de2ad0..3a7202d0f4 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -79,10 +79,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +104,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
- return hidden_states
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -507,20 +512,21 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -530,12 +536,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
- + controlnet_single_block_samples[index_block // interval_control]
- )
-
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
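
A hedged, self-contained sketch of the new single-block calling convention introduced above: the text and image streams go in and come out separately instead of being concatenated by the caller. The constructor values mirror the FLUX defaults but are assumptions here.

```py
# Illustrative only: the refactored single-stream block concatenates the two streams
# internally and returns (encoder_hidden_states, hidden_states).
import torch
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)
hidden_states = torch.randn(1, 4096, 3072)          # packed image tokens
encoder_hidden_states = torch.randn(1, 512, 3072)   # T5 text tokens
temb = torch.randn(1, 3072)                         # pooled time/text conditioning

encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
)
print(encoder_hidden_states.shape, hidden_states.shape)  # (1, 512, 3072) (1, 4096, 3072)
```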
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index 0ae7f2c00d..bdb9201e62 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -22,6 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -71,14 +72,22 @@ class WanAttnProcessor2_0:
if rotary_emb is not None:
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
- x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
+ x1, x2 = x[..., 0], x[..., 1]
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
@@ -179,7 +188,11 @@ class WanTimeTextImageEmbedding(nn.Module):
class WanRotaryPosEmbed(nn.Module):
def __init__(
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
):
super().__init__()
@@ -189,38 +202,55 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
-
- freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
for dim in [t_dim, h_dim, w_dim]:
- freq = get_1d_rotary_pos_embed(
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
)
- freqs.append(freq)
- self.freqs = torch.cat(freqs, dim=1)
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- freqs = self.freqs.to(hidden_states.device)
- freqs = freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+
+ return freqs_cos, freqs_sin
+@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
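
The rotary refactor above replaces complex-valued frequencies with real cos/sin buffers. A small, self-contained numerical check (not part of the diff) showing the two formulations agree on interleaved channel pairs:

```py
# Sanity sketch: real cos/sin rotation over interleaved pairs matches the previous
# complex-multiplication formulation used by WanAttnProcessor2_0.
import torch

def rope_complex(x, freqs):  # old style: complex multiply
    x_c = torch.view_as_complex(x.to(torch.float64).unflatten(-1, (-1, 2)))
    return torch.view_as_real(x_c * freqs).flatten(-2).type_as(x)

def rope_real(x, cos, sin):  # new style: interleaved cos/sin
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

x = torch.randn(1, 8, 16, 64)                        # (batch, heads, seq_len, head_dim)
angles = torch.rand(16, 32, dtype=torch.float64)     # one angle per (position, channel pair)
freqs = torch.polar(torch.ones_like(angles), angles) # e^{i * theta}
cos, sin = angles.cos().float(), angles.sin().float()

assert torch.allclose(rope_complex(x, freqs), rope_real(x, cos, sin), atol=1e-5)
```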
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 892c6f5a4c..1904c02999 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -141,6 +141,7 @@ else:
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
"FluxKontextPipeline",
+ "FluxKontextInpaintPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -610,6 +611,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index f08a3c35c2..3c5994172c 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index fe3e8ae388..cf6ccebc47 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a982f4b275..d1f02ca9c9 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- ofs=ofs_emb,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ ofs=ofs_emb,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 7c50bdcb7d..230c8ca296 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 880253459e..d8374b694f 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
- noise_pred_cond = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- original_size=original_size,
- target_size=target_size,
- crop_coords=crops_coords_top_left,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- # perform guidance
- if self.do_classifier_free_guidance:
- noise_pred_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 7d6a29ceca..598e3b5b6d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -29,7 +29,7 @@ from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
"""
-class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
@@ -116,6 +116,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
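
For reference, the opt-in pattern this hunk uses: a pipeline inherits `DeprecatedPipelineMixin` and records the last release in which it was fully supported. A minimal hedged sketch (the class name is hypothetical):

```py
# Hedged sketch of the deprecation opt-in shown above.
from diffusers.pipelines.pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline

class MyLegacyPipeline(DeprecatedPipelineMixin, DiffusionPipeline):  # hypothetical pipeline
    _last_supported_version = "0.33.1"  # last diffusers release that fully supported it
```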
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 117ce46f20..ea25c148e2 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -34,6 +34,7 @@ else:
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
_import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
+ _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -54,6 +55,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
from .pipeline_flux_kontext import FluxKontextPipeline
+ from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index 4c83ae7405..073d94750a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -912,32 +912,35 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- if negative_image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
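
A hedged usage sketch for the cond/uncond split above: passing a negative prompt with `true_cfg_scale > 1.0` triggers the second transformer pass, so both named cache contexts are exercised. Prompt and scale values are illustrative.

```py
# Illustrative only: run FluxPipeline with true classifier-free guidance.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
image = pipe(
    prompt="a cat wearing a tiny spacesuit, studio photo",
    negative_prompt="blurry, low quality",
    true_cfg_scale=4.0,   # > 1.0 enables the uncond pass (the "uncond" cache context)
    num_inference_steps=28,
).images[0]
```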
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index b4f77cf019..ea49821adc 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -163,9 +163,9 @@ class FluxControlPipeline(
TextualInversionLoaderMixin,
):
r"""
- The Flux pipeline for controllable text-to-image generation.
+ The Flux pipeline for controllable text-to-image generation with image conditions.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/flux-1-tools
Args:
transformer ([`FluxTransformer2DModel`]):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
index 07b9b895a4..94901ee0b6 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -195,9 +195,9 @@ class FluxKontextPipeline(
FluxIPAdapterMixin,
):
r"""
- The Flux Kontext pipeline for text-to-image generation.
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
Args:
transformer ([`FluxTransformer2DModel`]):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..2b4abe8b24
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,1459 @@
+# Copyright 2025 ZenAI. All rights reserved.
+# author: @vuongminh1907
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ # Inpainting with text only
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> prompt = "Change the yellow dinosaur to green one"
+ >>> img_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+ ... )
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+ >>> image.save("kontext_inpainting_normal.png")
+ ```
+
+ # Inpainting with image conditioning
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "Replace this ball"
+ >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+ ... )
+ >>> image_reference_url = (
+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image_reference = load_image(image_reference_url)
+
+ >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
+ >>> image = pipe(
+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+ ... ).images[0]
+ >>> image.save("kontext_inpainting_ref.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
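
These buckets are the roughly one-megapixel aspect-ratio pairs Kontext was trained on. As a hedged aside (not part of this file), a caller could snap an arbitrary size to the nearest bucket, assuming the `(width, height)` ordering used elsewhere in the Kontext pipelines:

```py
# Hedged helper sketch (assumes (width, height) ordering of the buckets above).
def closest_kontext_resolution(width: int, height: int) -> tuple:
    aspect = width / height
    return min(PREFERRED_KONTEXT_RESOLUTIONS, key=lambda wh: abs(wh[0] / wh[1] - aspect))

print(closest_kontext_resolution(1920, 1080))  # -> a landscape bucket such as (1392, 752)
```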
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
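
A quick worked example of the shift schedule above, using its defaults and an assumed 1024x1024 input (128x128 latents, packed into 4096 tokens):

```py
# Worked example (defaults of calculate_shift; 1024x1024 image -> 4096 packed tokens).
base_seq_len, max_seq_len = 256, 4096
base_shift, max_shift = 0.5, 1.15

m = (max_shift - base_shift) / (max_seq_len - base_seq_len)   # ~1.693e-4
b = base_shift - m * base_seq_len                             # ~0.4567
mu = 4096 * m + b                                             # 1.15, i.e. the max shift
```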
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+    The Flux Kontext pipeline for inpainting with text and, optionally, an image reference.
+
+    Reference: https://bfl.ai/announcements/flux-1-kontext-dev
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
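
A short worked example of the strength handling above (values are illustrative):

```py
# With 28 inference steps and strength=0.5, only the last 14 timesteps are denoised;
# strength=1.0 keeps the full schedule (the default used in the docstring examples).
num_inference_steps, strength = 28, 0.5
init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 14.0
t_start = int(max(num_inference_steps - init_timestep, 0))                # 14
```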
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
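
A hedged shape check for the packing helpers above, assuming a 1024x1024 input with the default VAE (16 latent channels, 8x spatial compression):

```py
# Illustrative shape round-trip through the static packing helpers.
import torch
from diffusers import FluxKontextInpaintPipeline

latents = torch.randn(1, 16, 128, 128)   # 1024 / 8 = 128 latent pixels per side
packed = FluxKontextInpaintPipeline._pack_latents(latents, 1, 16, 128, 128)
print(packed.shape)                      # torch.Size([1, 4096, 64])
unpacked = FluxKontextInpaintPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
print(unpacked.shape)                    # torch.Size([1, 16, 128, 128])
```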
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ timestep: int,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ image_reference: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ # Prepare image latents
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ # Prepare image reference latents
+ image_reference_latents = image_reference_ids = None
+ if image_reference is not None:
+ image_reference = image_reference.to(device=device, dtype=dtype)
+ if image_reference.shape[1] != self.latent_channels:
+ image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator)
+ else:
+ image_reference_latents = image_reference
+ if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_reference_latents.shape[0]
+ image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_reference_latents = torch.cat([image_reference_latents], dim=0)
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device=device, dtype=dtype)
+ latents = noise
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ if image_reference_latents is not None:
+ image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
+ image_reference_latents = self._pack_latents(
+ image_reference_latents,
+ batch_size,
+ num_channels_latents,
+ image_reference_latent_height,
+ image_reference_latent_width,
+ )
+ image_reference_ids = self._prepare_latent_image_ids(
+ batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
+ )
+ # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_reference_ids[..., 0] = 1
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concatenating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ image_reference: Optional[PipelineImageInput] = None,
+ mask_image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ padding_mask_crop: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be inpainted (the parts of the image
+ masked out with `mask_image` are repainted according to `prompt` and `image_reference`). For both
+ numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's a tensor or a list
+ of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
+ latents as `image`; latents passed directly are not encoded again.
+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
+ masked area. For both numpy arrays and pytorch tensors, the expected value range is `[0, 1]`. If it's
+ a tensor or a list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a
+ numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can
+ also accept image latents as `image_reference`; latents passed directly are not encoded again.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+ `(B, H, W)`, `(1, H, W)`, or `(H, W)`, and for a numpy array `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`,
+ or `(H, W)`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` will
+ be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
+ height (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to `self.default_sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates the extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that
+ area based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded
+ area before resizing to the original image size for inpainting. This is useful when the masked area is
+ small while the image is large and contains information irrelevant to inpainting, such as background.
+ num_inference_steps (`int`, *optional*, defaults to 28):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that is called at the end of each denoising step during inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
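+ # Illustrative example (assuming the usual vae_scale_factor of 8, i.e. multiple_of == 16): a requested
+ # 1920x1080 (16:9) frame with max_area = 1024**2 is first rescaled to about 1365x768 and then snapped
+ # down to multiples of 16, giving a final 1360x768 generation size.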
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+
+ # Use the (possibly resized) input image resolution for the generation
+ width = image_width
+ height = image_height
+
+ # 2.1 Preprocess mask
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ image = self.image_processor.preprocess(
+ image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ else:
+ raise ValueError("image must be provided correctly for inpainting")
+
+ init_image = image.to(dtype=torch.float32)
+
+ # 2.1 Preprocess image_reference
+ if image_reference is not None and not (
+ isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+ ):
+ if (
+ isinstance(image_reference, list)
+ and isinstance(image_reference[0], torch.Tensor)
+ and image_reference[0].ndim == 4
+ ):
+ image_reference = torch.cat(image_reference, dim=0)
+ img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
+ image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+ img_reference
+ )
+ aspect_ratio = image_reference_width / image_reference_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_reference_width, image_reference_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_reference_width = image_reference_width // multiple_of * multiple_of
+ image_reference_height = image_reference_height // multiple_of * multiple_of
+ image_reference = self.image_processor.resize(
+ image_reference, image_reference_height, image_reference_width
+ )
+ image_reference = self.image_processor.preprocess(
+ image_reference,
+ image_reference_height,
+ image_reference_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ else:
+ image_reference = None
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
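+ # e.g., with the default scheduler config values above and vae_scale_factor == 8, a 1024x1024 generation
+ # gives image_seq_len = 64 * 64 = 4096, so `mu` lands at max_shift (1.15) and the sigma schedule is
+ # shifted accordingly for the long token sequence.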
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
+ self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image_reference,
+ )
+ )
+
+ if image_reference_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension
+ elif image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * (mask_condition < 0.5)
+
+ mask, _ = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_reference_latents is not None:
+ latent_model_input = torch.cat([latents, image_reference_latents], dim=1)
+ elif image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
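+ # The transformer runs on the concatenated sequence (target tokens + conditioning image tokens); only
+ # the prediction for the target tokens is kept for the scheduler step.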
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
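+ # Keep the original content where the mask is 0 (black) and the freshly denoised latents where the mask
+ # is 1 (white); the original latents are re-noised to the current timestep so both terms stay on schedule.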
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index b617e4f8b2..2cbb4af2b4 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -693,28 +693,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=prompt_attention_mask,
- pooled_projections=pooled_prompt_embeds,
- guidance=guidance,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
- encoder_attention_mask=negative_prompt_attention_mask,
- pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 3b58b4a45a..77ba751700 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index fa9ee4fc7b..217478f418 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- video_coords=video_coords,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 99412b6962..8793d81377 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 7712b41524..3c0f908296 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 6df66118b0..d14dac91f1 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -533,22 +533,24 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
index acff268c9b..63b4a109ff 100644
--- a/src/diffusers/schedulers/scheduling_scm.py
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
- print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 2981f3a420..6d25047a0f 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -17,6 +17,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class FirstBlockCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -51,6 +66,10 @@ def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
+def apply_first_block_cache(*args, **kwargs):
+ requires_backends(apply_first_block_cache, ["torch"])
+
+
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index a0c6d84a32..9cb869c67a 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -692,6 +692,21 @@ class FluxInpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class FluxKontextInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxKontextPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 3907bdd5b3..651fa27294 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+ rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()
- # layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"use_dora": use_dora,
"lora_bias": lora_bias,
}
+
+ # Example: try load FusionX LoRA into Wan VACE
+ exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+ if exclude_modules:
+ if not is_peft_version(">=", "0.14.0"):
+ msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` argument to `LoraConfig`. Please update it by running `pip install -U
+peft`. In most cases, this can be safely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+ """
+ logger.debug(msg)
+ else:
+ lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
return lora_config_kwargs
@@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None:
def _create_lora_config(
- state_dict,
- network_alphas,
- metadata,
- rank_pattern_dict,
- is_unet: bool = True,
+ state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
from peft import LoraConfig
@@ -306,7 +318,12 @@ def _create_lora_config(
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
- rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+ rank_pattern_dict,
+ network_alpha_dict=network_alphas,
+ peft_state_dict=state_dict,
+ is_unet=is_unet,
+ model_state_dict=model_state_dict,
+ adapter_name=adapter_name,
)
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+ """
+ Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing
+ `model_state_dict` and `peft_state_dict`, adding a module from `model_state_dict` to the exclusion set if it
+ doesn't exist in `peft_state_dict`.
+ """
+ if model_state_dict is None:
+ return
+ all_modules = set()
+ string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+ for name in model_state_dict.keys():
+ if string_to_replace:
+ name = name.replace(string_to_replace, "")
+ if "." in name:
+ module_name = name.rsplit(".", 1)[0]
+ all_modules.add(module_name)
+
+ target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+ exclude_modules = list(all_modules - target_modules_set)
+
+ return exclude_modules
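+
+
+# Illustrative sketch (hypothetical module names, not part of the library API): if `model_state_dict` holds
+# weights for "blocks.0.attn1.to_q" and "vace_blocks.0.proj_out" while the LoRA `peft_state_dict` only targets
+# "blocks.0.attn1.to_q", the derived list is ["vace_blocks.0.proj_out"], which is then forwarded to
+# `LoraConfig(exclude_modules=...)` when the installed peft version supports it.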
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index e5da39c1d8..a136d1b6bd 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -421,6 +421,10 @@ def require_big_accelerator(test_case):
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
Flux, SD3, Cog, etc.
"""
+ import pytest
+
+ test_case = pytest.mark.big_accelerator(test_case)
+
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
@@ -990,10 +994,10 @@ def pytest_terminal_summary_main(tr, id):
config.option.tbstyle = orig_tbstyle
-# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
+# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
- To decorate flaky tests. They will be retried on failures.
+ To decorate flaky tests (methods or entire classes). They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
@@ -1005,22 +1009,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
etc.)
"""
- def decorator(test_func_ref):
- @functools.wraps(test_func_ref)
+ def decorator(obj):
+ # If decorating a class, wrap each test method on it
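+ # e.g. applying `@is_flaky(max_attempts=10, description="very flaky class")` to a `unittest.TestCase`
+ # subclass retries every `test_*` method on it (see tests/lora/test_lora_layers_wanvace.py below).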
+ if inspect.isclass(obj):
+ for attr_name, attr_value in list(obj.__dict__.items()):
+ if callable(attr_value) and attr_name.startswith("test"):
+ # recursively decorate the method
+ setattr(obj, attr_name, decorator(attr_value))
+ return obj
+
+ # Otherwise we're decorating a single test function / method
+ @functools.wraps(obj)
def wrapper(*args, **kwargs):
retry_count = 1
-
while retry_count < max_attempts:
try:
- return test_func_ref(*args, **kwargs)
-
+ return obj(*args, **kwargs)
except Exception as err:
- print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+ msg = (
+ f"[FLAKY] {description or obj.__name__!r} "
+ f"failed on attempt {retry_count}/{max_attempts}: {err}"
+ )
+ print(msg, file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
- return test_func_ref(*args, **kwargs)
+ return obj(*args, **kwargs)
return wrapper
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index ffc1119727..61a5d95b69 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+def unwrap_module(module):
+ """Unwraps a module if it was compiled with torch.compile()"""
+ return module._orig_mod if is_compiled_module(module) else module
+
+
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
diff --git a/tests/conftest.py b/tests/conftest.py
index 7e9c4e8f39..3237fb9c7b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,10 @@ sys.path.insert(1, git_repo_path)
warnings.simplefilter(action="ignore", category=FutureWarning)
+def pytest_configure(config):
+ config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
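+ # With the marker registered, heavy suites can be selected or deselected via standard pytest filtering,
+ # e.g. `pytest -m big_accelerator tests/` or `pytest -m "not big_accelerator" tests/`.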
+
+
def pytest_addoption(parser):
from diffusers.utils.testing_utils import pytest_addoption_shared
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 336ac2246f..95f1e137e9 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -20,7 +20,6 @@ import tempfile
import unittest
import numpy as np
-import pytest
import safetensors.torch
import torch
from parameterized import parameterized
@@ -813,7 +812,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -960,7 +958,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 19e31f320d..4cbd6523e7 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -198,7 +197,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index a81128fa44..1c5a9b00e9 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -120,7 +120,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
# We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu"))
+ @slow
+ @require_torch_accelerator
+ def test_integration_set_lora_device_different_target_layers(self):
+ # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+ # layers, see #11833
+ from peft import LoraConfig
+
+ path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+ # configs partly target the same, partly different layers
+ config0 = LoraConfig(target_modules=["to_k", "to_v"])
+ config1 = LoraConfig(target_modules=["to_k", "to_q"])
+ pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+ pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+ pipe = pipe.to(torch_device)
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.unet),
+ "Lora not correctly set in unet",
+ )
+
+ # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+ modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+ modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+ self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+ self.assertTrue(modules_adapter_0 - modules_adapter_1)
+ self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+ # setting both separately works
+ pipe.set_lora_device(["adapter-0"], "cpu")
+ pipe.set_lora_device(["adapter-1"], "cpu")
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+
+ # setting both at once also works
+ pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+
@slow
@nightly
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8a8f2a676d..8928ccbac2 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -139,7 +138,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 95ec44b2bf..fe26a56e77 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -24,7 +24,11 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ skip_mps,
+)
sys.path.append(".")
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
new file mode 100644
index 0000000000..f976577653
--- /dev/null
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -0,0 +1,222 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+import safetensors.torch
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ is_flaky,
+ require_peft_backend,
+ require_peft_version_greater,
+ skip_mps,
+ torch_device,
+)
+
+
+if is_peft_available():
+ from peft.utils import get_peft_model_state_dict
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+@is_flaky(max_attempts=10, description="very flaky class")
+class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = WanVACEPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 32,
+ "freq_dim": 16,
+ "ffn_dim": 16,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 16,
+ "vace_layers": [0],
+ "vace_in_channels": 72,
+ }
+ transformer_cls = WanVACETransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 4,
+ "dim_mult": [1, 1, 1, 1],
+ "latents_mean": torch.randn(4).numpy().tolist(),
+ "latents_std": torch.randn(4).numpy().tolist(),
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 16, 16, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+ height, width = 16, 16
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ video = [Image.new("RGB", (height, width))] * num_frames
+ mask = [Image.new("L", (height, width), 0)] * num_frames
+
+ pipeline_inputs = {
+ "video": video,
+ "mask": mask,
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": height,
+ "width": height,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ @pytest.mark.xfail(
+ condition=True,
+ reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
+ strict=True,
+ )
+ def test_layerwise_casting_inference_denoiser(self):
+ super().test_layerwise_casting_inference_denoiser()
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules_wanvace(self):
+ scheduler_cls = self.scheduler_classes[0]
+ exclude_module_name = "vace_blocks.0.proj_out"
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # `exclude_modules` is currently only supported for the denoiser.
+ denoiser_lora_config.target_modules = ["proj_out"]
+ denoiser_lora_config.exclude_modules = [exclude_module_name]
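+        # "proj_out" also matches the projection inside the VACE block, so excluding
+        # "vace_blocks.0.proj_out" checks that exclusion takes precedence over a matching target pattern.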
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ # The state dict shouldn't contain the modules to be excluded from LoRA.
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
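+            # Unload so that the subsequent `load_lora_weights(tmpdir)` starts from an adapter-free pipeline.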
+ pipe.unload_lora_weights()
+
+ # Check in the loaded state dict.
+ loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
+ self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+
+ # Check in the state dict obtained after loading LoRA.
+ pipe.load_lora_weights(tmpdir)
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "LoRA outputs should match.",
+ )
+
+ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+ super().test_simple_inference_with_text_denoiser_lora_and_scale()
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index acd6f5f343..91ca188137 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import os
import re
@@ -291,9 +292,21 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
- def check_if_adapters_added_correctly(
- self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
- ):
+ def _get_exclude_modules(self, pipe):
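+        """Derives the set of denoiser module names that the currently attached LoRA does not target."""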
+ from diffusers.utils.peft_utils import _derive_exclude_modules
+
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+ modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+ denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+ pipe.unload_lora_weights()
+ denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+ exclude_modules = _derive_exclude_modules(
+ denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+ )
+ return exclude_modules
+
+ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +358,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -428,7 +441,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -484,7 +497,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -522,7 +535,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -554,7 +567,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -589,7 +602,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -640,7 +653,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +704,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +747,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -775,7 +788,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -819,7 +832,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -857,7 +870,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -893,7 +906,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1023,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1032,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1759,7 +1772,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1850,7 +1863,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1950,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2132,7 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2250,7 @@ class PeftLoraLoaderMixinTests:
)
pipe = self.pipeline_class(**components)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2290,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2322,77 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
+ def test_lora_unload_add_adapter(self):
+ """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # unload and then add.
+ pipe.unload_lora_weights()
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules(self):
+ """
+        Checks that `exclude_modules` works as expected. The test first creates a
+        pipeline and inserts a LoRA config into it, then derives the set of modules
+        to exclude by comparing the denoiser state dict against the denoiser LoRA
+        state dict.
+
+        A fresh LoRA config carrying those `exclude_modules` is then applied and the
+        resulting outputs are verified.
+ """
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        # `exclude_modules` is currently only supported for the denoiser.
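+        # Derive the exclusion set on a deep copy so the original pipeline stays adapter-free
+        # for the actual `exclude_modules` run below.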
+ pipe_cp = copy.deepcopy(pipe)
+ pipe_cp, _ = self.add_adapters_to_pipeline(
+ pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+ pipe_cp.to("cpu")
+ del pipe_cp
+
+ denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(tmpdir)
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+            "LoRA outputs should match.",
+ )
+
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index dcc7ae16a4..def81ecd64 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1350,7 +1350,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
- print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin:
"""
+ different_shapes_for_compilation = None
+
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2056,11 +2057,13 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
+ - optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
+ different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2110,19 +2113,30 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
- model = torch.compile(model, mode="reduce-overhead")
+ model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
- output0_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+            # Additionally check that dynamic compilation works across different input shapes.
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output0_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
- output1_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output1_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2240,3 +2254,23 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+ @require_torch_version_greater("2.7.1")
+ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+ different_shapes_for_compilation = self.different_shapes_for_compilation
+ if different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
+ # variable to represent input sizes that are the same. For more details,
+ # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+ torch.fx.experimental._config.use_duck_shape = False
+
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
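+        # `error_on_recompile=True` turns any recompilation into a hard failure, so exercising
+        # several (height, width) pairs inside `check_model_hotswap` verifies the dynamic-shape path.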
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ self.check_model_hotswap(
+ do_compile=True,
+ rank0=rank0,
+ rank1=rank1,
+ target_modules0=target_modules,
+ )
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 4552b2e1f5..68b5c02bc0 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -186,6 +186,10 @@ class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
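+    # (height, width) pairs exercised by the dynamic-shape hotswap/compile test.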
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
deleted file mode 100644
index 94759d1f20..0000000000
--- a/tests/pipelines/amused/test_amused.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=8,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "height": 4,
- "width": 4,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.4011, 0.3992, 0.379, 0.3856, 0.3772, 0.3711, 0.3919, 0.385, 0.3625])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_256_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158])
- assert np.abs(image_slice - expected_slice).max() < 0.007
-
- def test_amused_512(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1199, 0.1171, 0.1229, 0.1188, 0.1210, 0.1147, 0.1260, 0.1346, 0.1152])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_512_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1509, 0.1492, 0.1531, 0.1485, 0.1501, 0.1465, 0.1581, 0.1690, 0.1499])
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
deleted file mode 100644
index a76d82a2f0..0000000000
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedImg2ImgPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.999, 0.9954, 1.0])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_256_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256", torch_dtype=torch.float16, variant="fp16")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.998, 0.998, 0.994, 0.9944, 0.996, 0.9908, 1.0, 1.0, 0.9986])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_512(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2809, 0.1879, 0.2027, 0.2418, 0.1852, 0.2145, 0.2484, 0.2425, 0.2317])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2795, 0.1867, 0.2028, 0.2450, 0.1856, 0.2140, 0.2473, 0.2406, 0.2313])
- assert np.abs(image_slice - expected_slice).max() < 0.1
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
deleted file mode 100644
index 0b025b8a3f..0000000000
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ /dev/null
@@ -1,281 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- Expectations,
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedInpaintPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=32,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image[0, 0, 0, 0] = 0
- mask_image[0, 0, 0, 1] = 0
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- "mask_image": mask_image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedInpaintPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_256_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0735, 0.0749, 0.065, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0])
- assert np.abs(image_slice - expected_slice).max() < 0.05
-
- def test_amused_512_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slices = Expectations(
- {
- ("xpu", 3): np.array(
- [
- 0.0274,
- 0.0211,
- 0.0154,
- 0.0257,
- 0.0299,
- 0.0170,
- 0.0326,
- 0.0420,
- 0.0150,
- ]
- ),
- ("cuda", 7): np.array(
- [
- 0.0227,
- 0.0157,
- 0.0098,
- 0.0213,
- 0.0250,
- 0.0127,
- 0.0280,
- 0.0380,
- 0.0095,
- ]
- ),
- }
- )
- expected_slice = expected_slices.get_expectation()
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py
deleted file mode 100644
index eb4139f0dc..0000000000
--- a/tests/pipelines/audioldm/test_audioldm.py
+++ /dev/null
@@ -1,461 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers import (
- ClapTextConfig,
- ClapTextModelWithProjection,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AudioLDMPipeline,
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AudioLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- layers_per_block=1,
- norm_num_groups=8,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(8, 16),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=8,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[8, 16],
- in_channels=1,
- out_channels=1,
- norm_num_groups=8,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = ClapTextModelWithProjection(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
-
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_audioldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = audioldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- prompt_embeds = prompt_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- prompt_embeds = F.normalize(prompt_embeds, dim=-1)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = audioldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- text_embeds = text_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- text_embeds = F.normalize(text_embeds, dim=-1)
-
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = audioldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = audioldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_audioldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_audioldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = audioldm_pipe.vocoder.config
- config.model_in_dim *= 2
- audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
-
-@nightly
-class AudioLDMPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[77230:77240]
- expected_slice = np.array(
- [-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-2
-
-
-@nightly
-class AudioLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm_lms(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[27780:27790]
- expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 3e-2
diff --git a/tests/pipelines/blipdiffusion/__init__.py b/tests/pipelines/blipdiffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
deleted file mode 100644
index 0e3f723fc6..0000000000
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.utils.testing_utils import enable_full_determinism
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=8,
- intermediate_size=8,
- projection_dim=8,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(8,),
- norm_num_groups=8,
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- sample_size=8,
- )
-
- blip_vision_config = {
- "hidden_size": 8,
- "intermediate_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 8,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 8,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=8,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- norm_num_groups=8,
- layers_per_block=1,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_blipdiffusion(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
-
- expected_slice = np.array(
- [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
- )
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index c725589781..a6cb558513 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -45,7 +46,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
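
Both this hunk and the HunyuanVideo and Flux hunks later in the patch fold FirstBlockCacheTesterMixin into the fast-test base classes. As a rough, self-contained sketch of the mixin pattern these suites rely on (all names below, such as DummyPipe, CacheTesterMixin, and enable_cache, are hypothetical stand-ins rather than the actual diffusers mixin or its cache hooks), a shared mixin contributes a test method that every concrete pipeline test class inherits:

import unittest

import numpy as np


class DummyPipe:
    # Stand-in pipeline: enabling the cache only flips a flag and must not change outputs.
    def __init__(self):
        self.cache_enabled = False

    def enable_cache(self):
        self.cache_enabled = True

    def __call__(self, seed=0):
        return np.random.default_rng(seed).standard_normal((4, 4))


class CacheTesterMixin:
    # Concrete test classes are expected to provide get_pipeline() and get_inputs().
    def test_cached_output_matches_baseline(self):
        pipe = self.get_pipeline()
        baseline = pipe(**self.get_inputs())
        pipe.enable_cache()
        cached = pipe(**self.get_inputs())
        # Cached inference should stay numerically close to the uncached baseline.
        assert np.abs(baseline - cached).max() < 1e-3


class DummyPipelineFastTests(CacheTesterMixin, unittest.TestCase):
    def get_pipeline(self):
        return DummyPipe()

    def get_inputs(self):
        return {"seed": 0}


if __name__ == "__main__":
    unittest.main()

Listing unittest.TestCase last keeps the mixin methods ahead of TestCase in the method resolution order, which is presumably why the base lists in these hunks were reordered.
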
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
deleted file mode 100644
index 100082b6f0..0000000000
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import (
- AutoencoderKL,
- BlipDiffusionControlNetPipeline,
- ControlNetModel,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionControlNetPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=16,
- intermediate_size=16,
- projection_dim=16,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(32,),
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- norm_num_groups=16,
- sample_size=16,
- )
-
- blip_vision_config = {
- "hidden_size": 16,
- "intermediate_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 16,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 16,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=16,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- norm_num_groups=4,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=16,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
- controlnet = ControlNetModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- in_channels=4,
- norm_num_groups=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- cross_attention_dim=16,
- conditioning_embedding_out_channels=(8, 16),
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "controlnet": controlnet,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
- cond_image = np.random.rand(32, 32, 3) * 255
- cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "condtioning_image": cond_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.4803, 0.3865, 0.1422, 0.6119, 0.2283, 0.6365, 0.5453, 0.5205, 0.3581])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_blipdiffusion_controlnet(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
- expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5ee94b09ba..5b336edc7a 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -211,7 +210,6 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
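
Several slow-test classes in this patch drop the explicit @pytest.mark.big_accelerator line while keeping @require_big_accelerator. A minimal sketch of how a single decorator could carry both the marker and the hardware guard, making the separate mark redundant (this is an assumption about the helper's behaviour, not the actual implementation in diffusers.utils.testing_utils):

import pytest
import torch


def require_big_accelerator(test_case):
    # Hypothetical sketch: attach the `big_accelerator` mark (assumed to be registered
    # in the project's pytest config) and skip when no large-memory CUDA device exists.
    has_big_gpu = (
        torch.cuda.is_available()
        and torch.cuda.get_device_properties(0).total_memory / 2**30 >= 40
    )
    test_case = pytest.mark.big_accelerator(test_case)
    return pytest.mark.skipif(not has_big_gpu, reason="test requires a big accelerator")(test_case)

With the mark owned by the decorator, the test classes stay selectable by marker expression without repeating the mark at every call site.
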
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 712c26b0a2..1f1f800bcf 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -18,7 +18,6 @@ import unittest
from typing import Optional
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -221,7 +220,6 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline
diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
deleted file mode 100644
index 6f8422797c..0000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- DDIMScheduler,
- LCMScheduler,
- StableDiffusionControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_image,
- require_accelerator,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class ControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self, time_cond_proj_dim=None):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- norm_num_groups=4,
- time_cond_proj_dim=time_cond_proj_dim,
- use_linear_projection=True,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=1,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "numpy",
- "image": image,
- }
-
- return inputs
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- def test_controlnet_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=8)
- sd_pipe = StableDiffusionControlNetXSPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 16, 16, 3)
- expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device.
- # So we need to move the new pipe to the device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@require_torch_accelerator
-class ControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (768, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (512, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
- assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
deleted file mode 100644
index 24a8b9cd57..0000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
+++ /dev/null
@@ -1,393 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- EulerDiscreteScheduler,
- StableDiffusionXLControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_image,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionXLControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionXLControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- use_linear_projection=True,
- norm_num_groups=4,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
- cross_attention_dim=8,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=0.5,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = EulerDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- steps_offset=1,
- beta_schedule="scaled_linear",
- timestep_spacing="leading",
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=8,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- }
- return components
-
- # Copied from test_controlnet_sdxl.py
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "image": image,
- }
-
- return inputs
-
- # Copied from test_controlnet_sdxl.py
- def test_attention_slicing_forward_pass(self):
- return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- # Copied from test_controlnet_sdxl.py
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- # Copied from test_controlnet_sdxl.py
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- @unittest.skip("We test this functionality elsewhere already.")
- def test_save_load_optional_components(self):
- pass
-
- @require_torch_accelerator
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- pipe.unet.set_default_attn_processor()
-
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_multi_prompts(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
-
- # forward with single prompt
- inputs = self.get_dummy_inputs(torch_device)
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = inputs["prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "different prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # manually set a negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same negative_prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = inputs["negative_prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = "different negative prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # Copied from test_controlnetxs.py
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which isn't on the device.
- # So we need to move the new pipe to the device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (768, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (512, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
- assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/dance_diffusion/__init__.py b/tests/pipelines/dance_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py
deleted file mode 100644
index a2a1753214..0000000000
--- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = DanceDiffusionPipeline
- params = UNCONDITIONAL_AUDIO_GENERATION_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {
- "callback",
- "latents",
- "callback_steps",
- "output_type",
- "num_images_per_prompt",
- }
- batch_params = UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS
- test_attention_slicing = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet1DModel(
- block_out_channels=(32, 32, 64),
- extra_in_channels=16,
- sample_size=512,
- sample_rate=16_000,
- in_channels=2,
- out_channels=2,
- flip_sin_to_cos=True,
- use_timestep_embedding=False,
- time_embedding_type="fourier",
- mid_block_type="UNetMidBlock1D",
- down_block_types=("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
- up_block_types=("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
- )
- scheduler = IPNDMScheduler()
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "batch_size": 1,
- "generator": generator,
- "num_inference_steps": 4,
- }
- return inputs
-
- def test_dance_diffusion(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = DanceDiffusionPipeline(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = pipe(**inputs)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, components["unet"].sample_size)
- expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local()
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- return super().test_attention_slicing_forward_pass()
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_accelerator
-class PipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_dance_diffusion(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_dance_diffusion_fp16(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index cbdf617d71..0df0e028ff 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -25,6 +24,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
@@ -34,11 +34,12 @@ from ..test_pipelines_common import (
class FluxPipelineFastTests(
- unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -224,7 +225,6 @@ class FluxPipelineFastTests(
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,6 @@ class FluxPipelineSlowTests(unittest.TestCase):
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..615209264d
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,190 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextInpaintPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor, torch_device
+
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextInpaintPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextInpaintPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_inpaint_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reason, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+ # Because the output shape matches the input shape, we need to create a dummy image and mask image at each size
+ image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
+ mask_image = torch.ones((1, 1, height, width)).to(torch_device)
+
+ inputs.update(
+ {
+ "height": height,
+ "width": width,
+ "max_area": height * width,
+ "image": image,
+ "mask_image": mask_image,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
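
The test_flux_image_output_shape case above snaps the requested size down to a multiple of vae_scale_factor * 2 before comparing shapes. A quick standalone check of that arithmetic, assuming a scale factor of 8 as in the full-size Flux VAE (the tiny dummy VAE in this test has a much smaller factor):

# Assumed scale factor; chosen for illustration only.
vae_scale_factor = 8

for height, width in [(32, 32), (72, 56)]:
    expected_height = height - height % (vae_scale_factor * 2)
    expected_width = width - width % (vae_scale_factor * 2)
    print((height, width), "->", (expected_height, expected_width))

# (32, 32) -> (32, 32)
# (72, 56) -> (64, 48)
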
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index b8f36dfd3c..b73050a64d 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
@@ -19,7 +18,6 @@ from diffusers.utils.testing_utils import (
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index ecc5eba964..10101af75c 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +44,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
diff --git a/tests/pipelines/i2vgen_xl/__init__.py b/tests/pipelines/i2vgen_xl/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
deleted file mode 100644
index bedd63738a..0000000000
--- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import pytest
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- I2VGenXLPipeline,
-)
-from diffusers.models.unets import I2VGenXLUNet
-from diffusers.utils import is_xformers_available, load_image
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- is_torch_version,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase):
- pipeline_class = I2VGenXLPipeline
- params = frozenset(["prompt", "negative_prompt", "image"])
- batch_params = frozenset(["prompt", "negative_prompt", "image", "generator"])
- # No `output_type`.
- required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
-
- supports_dduf = False
- test_layerwise_casting = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- torch.manual_seed(0)
- unet = I2VGenXLUNet(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- num_attention_heads=None,
- norm_num_groups=2,
- )
-
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- torch.manual_seed(0)
- vision_encoder_config = CLIPVisionConfig(
- hidden_size=4,
- projection_dim=4,
- num_hidden_layers=2,
- num_attention_heads=2,
- image_size=32,
- intermediate_size=16,
- patch_size=1,
- )
- image_encoder = CLIPVisionModelWithProjection(vision_encoder_config)
-
- torch.manual_seed(0)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "image": input_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- "num_frames": 4,
- "width": 32,
- "height": 32,
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @pytest.mark.xfail(
- condition=is_torch_version(">=", "2.7"),
- reason="Test currently fails on PyTorch 2.7.",
- strict=False,
- )
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=0.006)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent(expected_max_difference=0.009)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.008)
-
- @unittest.skip("Deprecated functionality")
- def test_attention_slicing_forward_pass(self):
- pass
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=0.008)
-
- def test_model_cpu_offload_forward_pass(self):
- super().test_model_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_num_videos_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs, num_videos_per_prompt=2).frames
-
- assert frames.shape == (2, 4, 32, 32, 3)
- assert frames[0][0].shape == (32, 32, 3)
-
- image_slice = frames[0][0][-3:, -3:, -1]
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skip("Test not supported for now.")
- def test_encode_prompt_works_in_isolation(self):
- pass
-
-
-@slow
-@require_torch_accelerator
-class I2VGenXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_i2vgen_xl(self):
- pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
- )
-
- generator = torch.Generator("cpu").manual_seed(0)
- num_frames = 3
-
- output = pipe(
- image=image,
- prompt="my cat",
- num_frames=num_frames,
- generator=generator,
- num_inference_steps=3,
- output_type="np",
- )
-
- image = output.frames[0]
- assert image.shape == (num_frames, 704, 1280, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5482, 0.6244, 0.6274, 0.4584, 0.5935, 0.5937, 0.4579, 0.5767, 0.5892])
- assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index 1d1eb08234..bf0c7fde59 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
- num_layers=1,
+ num_layers=num_layers,
caption_channels=32,
)
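# Reviewer note (not part of the diff): `get_dummy_components` gains a `num_layers`
# argument because block-cache tests presumably need a dummy transformer with more than
# one block; with a single block there is nothing after the cached block to skip. A toy
# sketch of the parameterization (names are illustrative, not the real test helpers):
def get_dummy_transformer_config(num_layers: int = 1) -> dict:
    return {"num_attention_heads": 4, "attention_head_dim": 8, "num_layers": num_layers}

assert get_dummy_transformer_config(num_layers=2)["num_layers"] == 2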
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index 5b00261b06..f1684cce72 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -17,7 +17,6 @@ import inspect
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -33,13 +32,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+ PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -268,7 +269,6 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
@nightly
@require_torch_accelerator
@require_big_accelerator
-@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
diff --git a/tests/pipelines/musicldm/__init__.py b/tests/pipelines/musicldm/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py
deleted file mode 100644
index 5d6392865b..0000000000
--- a/tests/pipelines/musicldm/test_musicldm.py
+++ /dev/null
@@ -1,478 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- ClapAudioConfig,
- ClapConfig,
- ClapFeatureExtractor,
- ClapModel,
- ClapTextConfig,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- MusicLDMPipeline,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = MusicLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(32, 64),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=32,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=1,
- out_channels=1,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_branch_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=16,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- )
- audio_branch_config = ClapAudioConfig(
- spec_size=64,
- window_size=4,
- num_mel_bins=64,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- depths=[2, 2],
- num_attention_heads=[2, 2],
- num_hidden_layers=2,
- hidden_size=192,
- patch_size=2,
- patch_stride=2,
- patch_embed_input_channels=4,
- )
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=32
- )
- text_encoder = ClapModel(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
- feature_extractor = ClapFeatureExtractor.from_pretrained(
- "hf-internal-testing/tiny-random-ClapModel", hop_length=7900
- )
-
- torch.manual_seed(0)
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_musicldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0020, -0.0035, -0.0019, -0.0037, -0.0020, -0.0038, -0.0019]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = musicldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = musicldm_pipe.text_encoder.get_text_features(text_inputs)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = musicldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = musicldm_pipe.text_encoder.get_text_features(
- text_inputs,
- )
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = musicldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0019, -0.0035, -0.0018, -0.0037, -0.0021, -0.0038, -0.0018]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = musicldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = musicldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = musicldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = musicldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_musicldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = musicldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = musicldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_musicldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = musicldm_pipe.vocoder.config
- config.model_in_dim *= 2
- musicldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # The method component.dtype returns the dtype of the first parameter registered in the model, not the
- # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
-
- # Without the logit scale parameters, everything is float32
- model_dtypes.pop("text_encoder")
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # the CLAP sub-models are float32
- model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # Once we send to fp16, all params are in half-precision, including the logit scale
- pipe.to(dtype=torch.float16)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
-
-
-@nightly
-@require_torch_accelerator
-class MusicLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_musicldm(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[8680:8690]
- expected_slice = np.array(
- [-0.1042, -0.1068, -0.1235, -0.1387, -0.1428, -0.136, -0.1213, -0.1097, -0.0967, -0.0945]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
-
- def test_musicldm_lms(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe.scheduler = LMSDiscreteScheduler.from_config(musicldm_pipe.scheduler.config)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[58020:58030]
- expected_slice = np.array([0.3592, 0.3477, 0.4084, 0.4665, 0.5048, 0.5891, 0.6461, 0.5579, 0.4595, 0.4403])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
diff --git a/tests/pipelines/paint_by_example/__init__.py b/tests/pipelines/paint_by_example/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py
deleted file mode 100644
index f122c7411d..0000000000
--- a/tests/pipelines/paint_by_example/test_paint_by_example.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPImageProcessor, CLIPVisionConfig
-
-from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = PaintByExamplePipeline
- params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
- batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=9,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = PNDMScheduler(skip_prk_steps=True)
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- image_size=32,
- patch_size=4,
- )
- image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "image_encoder": image_encoder,
- "safety_checker": None,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def convert_to_pt(self, image):
- image = np.array(image.convert("RGB"))
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
- return image
-
- def get_dummy_inputs(self, device="cpu", seed=0):
- # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
- mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
- example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "example_image": example_image,
- "image": init_image,
- "mask_image": mask_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_paint_by_example_inpaint(self):
- components = self.get_dummy_components()
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**components)
- pipe = pipe.to("cpu")
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs()
- output = pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4686, 0.5687, 0.4007, 0.5218, 0.5741, 0.4482, 0.4940, 0.4629, 0.4503])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_paint_by_example_image_tensor(self):
- device = "cpu"
- inputs = self.get_dummy_inputs()
- inputs.pop("mask_image")
- image = self.convert_to_pt(inputs.pop("image"))
- mask_image = image.clamp(0, 1) / 2
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**self.get_dummy_components())
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(image=image, mask_image=mask_image[:, 0], **inputs)
- out_1 = output.images
-
- image = image.cpu().permute(0, 2, 3, 1)[0]
- mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]
-
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")
-
- output = pipe(**self.get_dummy_inputs())
- out_2 = output.images
-
- assert out_1.shape == (1, 64, 64, 3)
- assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_accelerator
-class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_paint_by_example(self):
- # make sure here that pndm scheduler skips prk
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/dog_in_bucket.png"
- )
- mask_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/mask.png"
- )
- example_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/panda.jpg"
- )
-
- pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(321)
- output = pipe(
- image=init_image,
- mask_image=mask_image,
- example_image=example_image,
- generator=generator,
- guidance_scale=5.0,
- num_inference_steps=50,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/pia/__init__.py b/tests/pipelines/pia/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
deleted file mode 100644
index 1156bf32da..0000000000
--- a/tests/pipelines/pia/test_pia.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-import diffusers
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DPMSolverMultistepScheduler,
- LCMScheduler,
- MotionAdapter,
- PIAPipeline,
- StableDiffusionPipeline,
- UNet2DConditionModel,
- UNetMotionModel,
-)
-from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import floats_tensor, require_accelerator, torch_device
-
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = PIAPipeline
- params = frozenset(
- [
- "prompt",
- "height",
- "width",
- "guidance_scale",
- "negative_prompt",
- "prompt_embeds",
- "negative_prompt_embeds",
- "cross_attention_kwargs",
- ]
- )
- batch_params = frozenset(["prompt", "image", "generator"])
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback_on_step_end",
- "callback_on_step_end_tensor_inputs",
- ]
- )
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- cross_attention_dim = 8
- block_out_channels = (8, 8)
-
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=block_out_channels,
- layers_per_block=2,
- sample_size=8,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="linear",
- clip_sample=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=block_out_channels,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=cross_attention_dim,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- torch.manual_seed(0)
- motion_adapter = MotionAdapter(
- block_out_channels=block_out_channels,
- motion_layers_per_block=2,
- motion_norm_num_groups=2,
- motion_num_attention_heads=4,
- conv_in_channels=9,
- )
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "motion_adapter": motion_adapter,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- image = floats_tensor((1, 3, 8, 8), rng=random.Random(seed)).to(device)
- inputs = {
- "image": image,
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 7.5,
- "output_type": "pt",
- }
- return inputs
-
- def test_from_pipe_consistent_config(self):
- assert self.original_pipeline_class == StableDiffusionPipeline
- original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
- original_kwargs = {"requires_safety_checker": False}
-
- # create original_pipeline_class(sd)
- pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
-
- # original_pipeline_class(sd) -> pipeline_class
- pipe_components = self.get_dummy_components()
- pipe_additional_components = {}
- for name, component in pipe_components.items():
- if name not in pipe_original.components:
- pipe_additional_components[name] = component
-
- pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
-
- # pipeline_class -> original_pipeline_class(sd)
- original_pipe_additional_components = {}
- for name, component in pipe_original.components.items():
- if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
- original_pipe_additional_components[name] = component
-
- pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
-
- # compare the config
- original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
- original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
- assert original_config_2 == original_config
-
- def test_motion_unet_loading(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
-
- assert isinstance(pipe.unet, UNetMotionModel)
-
- def test_ip_adapter(self):
- expected_pipe_slice = None
-
- if torch_device == "cpu":
- expected_pipe_slice = np.array(
- [
- 0.5475,
- 0.5769,
- 0.4873,
- 0.5064,
- 0.4445,
- 0.5876,
- 0.5453,
- 0.4102,
- 0.5247,
- 0.5370,
- 0.3406,
- 0.4322,
- 0.3991,
- 0.3756,
- 0.5438,
- 0.4780,
- 0.5087,
- 0.5248,
- 0.6243,
- 0.5506,
- 0.3491,
- 0.5440,
- 0.6111,
- 0.5122,
- 0.5326,
- 0.5180,
- 0.5538,
- ]
- )
- return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5476, 0.4092, 0.5289, 0.4755, 0.5092, 0.5186, 0.5403, 0.5287, 0.5467])
- return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- @unittest.skip("Attention slicing is not enabled in this pipeline")
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_inference_batch_single_identical(
- self,
- batch_size=2,
- expected_max_diff=1e-4,
- additional_params_copy_to_batched_inputs=["num_inference_steps"],
- ):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for components in pipe.components.values():
- if hasattr(components, "set_default_attn_processor"):
- components.set_default_attn_processor()
-
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
- # Reset generator in case it is has been used in self.get_dummy_inputs
- inputs["generator"] = self.get_generator(0)
-
- logger = logging.get_logger(pipe.__module__)
- logger.setLevel(level=diffusers.logging.FATAL)
-
- # batchify inputs
- batched_inputs = {}
- batched_inputs.update(inputs)
-
- for name in self.batch_params:
- if name not in inputs:
- continue
-
- value = inputs[name]
- if name == "prompt":
- len_prompt = len(value)
- batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
- batched_inputs[name][-1] = 100 * "very long"
-
- else:
- batched_inputs[name] = batch_size * [value]
-
- if "generator" in inputs:
- batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
-
- if "batch_size" in inputs:
- batched_inputs["batch_size"] = batch_size
-
- for arg in additional_params_copy_to_batched_inputs:
- batched_inputs[arg] = inputs[arg]
-
- output = pipe(**inputs)
- output_batch = pipe(**batched_inputs)
-
- assert output_batch[0].shape[0] == batch_size
-
- max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
- assert max_diff < expected_max_diff
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_prompt_embeds(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs.pop("prompt")
- inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
- pipe(**inputs)
-
- def test_free_init(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- pipe.enable_free_init(
- num_iters=2,
- use_fast_sampling=True,
- method="butterworth",
- order=4,
- spatial_stop_frequency=0.25,
- temporal_stop_frequency=0.25,
- )
- inputs_enable_free_init = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
-
- pipe.disable_free_init()
- inputs_disable_free_init = self.get_dummy_inputs(torch_device)
- frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
-
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
- max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
- self.assertGreater(
- sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
- )
- self.assertLess(
- max_diff_disabled,
- 1e-4,
- "Disabling of FreeInit should lead to results similar to the default pipeline results",
- )
-
- def test_free_init_with_schedulers(self):
- components = self.get_dummy_components()
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- schedulers_to_test = [
- DPMSolverMultistepScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- algorithm_type="dpmsolver++",
- steps_offset=1,
- clip_sample=False,
- ),
- LCMScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- steps_offset=1,
- clip_sample=False,
- ),
- ]
- components.pop("scheduler")
-
- for scheduler in schedulers_to_test:
- components["scheduler"] = scheduler
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
-
- inputs = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs).frames[0]
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
-
- self.assertGreater(
- sum_enabled,
- 1e1,
- "Enabling of FreeInit should lead to results different from the default pipeline results",
- )
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for component in pipe.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_without_offload = pipe(**inputs).frames[0]
- output_without_offload = (
- output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
- )
-
- pipe.enable_xformers_memory_efficient_attention()
- inputs = self.get_dummy_inputs(torch_device)
- output_with_offload = pipe(**inputs).frames[0]
- output_with_offload = (
- output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
- )
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/semantic_stable_diffusion/__init__.py b/tests/pipelines/semantic_stable_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
deleted file mode 100644
index b4d82b0fb2..0000000000
--- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
+++ /dev/null
@@ -1,617 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_semantic_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5753, 0.6114, 0.5001, 0.5034, 0.5470, 0.4729, 0.4971, 0.4867, 0.4867])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_torch_accelerator
- def test_semantic_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_accelerator
-class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_positive_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34673113,
- 0.38492733,
- 0.37597352,
- 0.34086335,
- 0.35650748,
- 0.35579205,
- 0.3384763,
- 0.34340236,
- 0.3573271,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.41887826,
- 0.37728766,
- 0.30138272,
- 0.41416335,
- 0.41664985,
- 0.36283392,
- 0.36191246,
- 0.43364465,
- 0.43001732,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_negative_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "an image of a crowded boulevard, realistic, 4k"
- edit = {
- "editing_prompt": "crowd, crowded, people",
- "reverse_editing_direction": True,
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 8.3,
- "edit_threshold": 0.9,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 9
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.43497998,
- 0.91814065,
- 0.7540739,
- 0.55580205,
- 0.8467265,
- 0.5389691,
- 0.62574506,
- 0.58897763,
- 0.50926757,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.3089719,
- 0.30500144,
- 0.29016042,
- 0.30630964,
- 0.325687,
- 0.29419225,
- 0.2908091,
- 0.28723598,
- 0.27696294,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_multi_cond_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a castle next to a river"
- edit = {
- "editing_prompt": ["boat on a river, boat", "monet, impression, sunrise"],
- "reverse_editing_direction": False,
- "edit_warmup_steps": [15, 18],
- "edit_guidance_scale": 6,
- "edit_threshold": [0.9, 0.8],
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 48
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.75163555,
- 0.76037145,
- 0.61785,
- 0.9189673,
- 0.8627701,
- 0.85189694,
- 0.8512813,
- 0.87012076,
- 0.8312857,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.73553365,
- 0.7537271,
- 0.74341905,
- 0.66480356,
- 0.6472925,
- 0.63039416,
- 0.64812905,
- 0.6749717,
- 0.6517102,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_guidance_fp16(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
- )
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34887695,
- 0.3876953,
- 0.375,
- 0.34423828,
- 0.3581543,
- 0.35717773,
- 0.3383789,
- 0.34570312,
- 0.359375,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.42285156,
- 0.36914062,
- 0.29077148,
- 0.42041016,
- 0.41918945,
- 0.35498047,
- 0.3618164,
- 0.4423828,
- 0.43115234,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
deleted file mode 100644
index 45fc70be23..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- StableDiffusionAttendAndExcitePipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- load_numpy,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-torch.backends.cuda.matmul.allow_tf32 = False
-
-
-@skip_mps
-class StableDiffusionAttendAndExcitePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionAttendAndExcitePipeline
- test_attention_slicing = False
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- # Attend-and-excite requires running a backward pass at inference time, and there is
- # no deterministic backward operator for the padding op, so deterministic algorithms
- # are disabled for these tests.
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a cat and a frog",
- "token_indices": [2, 5],
- "generator": generator,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- "max_iter_to_alter": 2,
- "thresholds": {0: 0.7},
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.6391, 0.6290, 0.4860, 0.5134, 0.5550, 0.4577, 0.5033, 0.5023, 0.4538])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice, expected_max_difference=3e-3)
-
- def test_inference(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- self.assertEqual(image.shape, (1, 64, 64, 3))
- expected_slice = np.array(
- [0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
-
- def test_inference_batch_consistent(self):
- # NOTE: larger batch sizes cause this test to time out, so only smaller batches are tested
- self._test_inference_batch_consistent(batch_sizes=[1, 2])
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=7e-4)
-
- def test_pt_np_pil_outputs_equivalent(self):
- super().test_pt_np_pil_outputs_equivalent(expected_max_diff=5e-4)
-
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=5e-4)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=4e-4)
-
- def test_karras_schedulers_shape(self):
- super().test_karras_schedulers_shape(num_inference_steps_for_strength_for_iterations=3)
-
- def test_from_pipe_consistent_forward_pass_cpu_offload(self):
- super().test_from_pipe_consistent_forward_pass_cpu_offload(expected_max_diff=5e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
- # Attend-and-excite requires running a backward pass at inference time, and there is
- # no deterministic backward operator for the padding op, so deterministic algorithms
- # are disabled for these tests.
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_attend_and_excite_fp16(self):
- generator = torch.manual_seed(51)
-
- pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.to(torch_device)
-
- prompt = "a painting of an elephant with glasses"
- token_indices = [5, 7]
-
- image = pipe(
- prompt=prompt,
- token_indices=token_indices,
- guidance_scale=7.5,
- generator=generator,
- num_inference_steps=5,
- max_iter_to_alter=5,
- output_type="np",
- ).images[0]
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/attend-and-excite/elephant_glasses.npy"
- )
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
- assert max_diff < 5e-1
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
deleted file mode 100644
index 9f8870af7b..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
+++ /dev/null
@@ -1,452 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMInverseScheduler,
- DDIMScheduler,
- DPMSolverMultistepInverseScheduler,
- DPMSolverMultistepScheduler,
- StableDiffusionDiffEditPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class StableDiffusionDiffEditPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
-):
- pipeline_class = StableDiffusionDiffEditPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"}
- # TODO: update image_params once the pipeline is refactored with VaeImageProcessor.preprocess
- image_params = frozenset([])
- image_latents_params = frozenset([])
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- inverse_scheduler = DDIMInverseScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_zero=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "inverse_scheduler": inverse_scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device)
- latents = floats_tensor((1, 2, 4, 16, 16), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a dog and a newt",
- "mask_image": mask,
- "image_latents": latents,
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_mask_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "source_prompt": "a cat and a frog",
- "target_prompt": "a dog and a newt",
- "generator": generator,
- "num_inference_steps": 2,
- "num_maps_per_mask": 2,
- "mask_encode_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_inversion_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "prompt": "a cat and a frog",
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "decode_latents": True,
- "output_type": "np",
- }
- return inputs
-
- def test_save_load_optional_components(self):
- if not hasattr(self.pipeline_class, "_optional_components"):
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- # set all optional components to None and update pipeline config accordingly
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
- pipe.register_modules(**dict.fromkeys(pipe._optional_components))
-
- inputs = self.get_dummy_inputs(torch_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(output - output_loaded).max()
- self.assertLess(max_diff, 1e-4)
-
- def test_mask(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_mask_inputs(device)
- mask = pipe.generate_mask(**inputs)
- mask_slice = mask[0, -3:, -3:]
-
- self.assertEqual(mask.shape, (1, 16, 16))
- expected_slice = np.array([0] * 9)
- max_diff = np.abs(mask_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
- self.assertEqual(mask[0, -3, -4], 0)
-
- def test_inversion(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5160, 0.5115, 0.5060, 0.5456, 0.4704, 0.5060, 0.5019, 0.4405, 0.4726],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=5e-3)
-
- def test_inversion_dpm(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"}
- components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args)
- components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args)
-
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5305, 0.4673, 0.5314, 0.5308, 0.4886, 0.5279, 0.5142, 0.4724, 0.4892],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
- raw_image = raw_image.convert("RGB").resize((256, 256))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_full(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.scheduler.clip_sample = True
-
- pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
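- # DiffEdit step 1: derive an edit mask by contrasting noise predictions for the source and target prompts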
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
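- # DiffEdit step 2: invert the input image into latents conditioned on the source prompt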
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=5,
- ).latents
-
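- # DiffEdit step 3: denoise the inverted latents toward the target prompt, editing only inside the mask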
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=5,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((256, 256))
- )
- / 255
- )
-
- assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 2e-1
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionDiffEditPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
-
- raw_image = raw_image.convert("RGB").resize((768, 768))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_dpm(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
- pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=25,
- ).latents
-
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=25,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((768, 768))
- )
- / 255
- )
- assert np.abs((expected_image - image).max()) < 5e-1
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 577ac4ebdd..2179ec8e22 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -233,7 +232,6 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index f5b5e63a81..7f913cb63d 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,7 +3,6 @@ import random
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -168,7 +167,6 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_gligen/__init__.py b/tests/pipelines/stable_diffusion_gligen/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
deleted file mode 100644
index 5d56f16803..0000000000
--- a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_gligen_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip("Test not supported as tokenizer is used for parsing bounding boxes.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/__init__.py b/tests/pipelines/stable_diffusion_gligen_text_image/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
deleted file mode 100644
index 3f092e02dd..0000000000
--- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENTextImagePipeline,
- UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion import CLIPImageProjection
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenTextImagePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENTextImagePipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated-text-image",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- image_encoder_config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- )
- image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
- image_project = CLIPImageProjection(hidden_size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": image_encoder,
- "image_project": image_project,
- "processor": processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- gligen_images = load_image(
- "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
- )
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_images": [gligen_images],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5052, 0.5546, 0.4567, 0.4770, 0.5195, 0.4085, 0.5026, 0.4909, 0.4495])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_stable_diffusion_gligen_text_image_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip(
- "Test not supported because of the use of `text_encoder` in `get_cross_attention_kwargs_with_grounded()`."
- )
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/__init__.py b/tests/pipelines/stable_diffusion_k_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
deleted file mode 100644
index dc7e62078a..0000000000
--- a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionKDiffusionPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_1(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0447, 0.0492, 0.0468, 0.0408, 0.0383, 0.0408, 0.0354, 0.0380, 0.0339])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_2(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=15,
- output_type="np",
- use_karras_sigmas=True,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array(
- [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
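- # two runs with the same generator seed and noise_sampler_seed should produce matching images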
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
-
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_ldm3d/__init__.py b/tests/pipelines/stable_diffusion_ldm3d/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py b/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
deleted file mode 100644
index 936e22b470..0000000000
--- a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- PNDMScheduler,
- StableDiffusionLDM3DPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-
-
-enable_full_determinism()
-
-
-class StableDiffusionLDM3DPipelineFastTests(unittest.TestCase):
- pipeline_class = StableDiffusionLDM3DPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
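- # 6 input/output channels so the VAE encodes and decodes RGB and depth jointly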
- in_channels=6,
- out_channels=6,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- image_slice_rgb = rgb[0, -3:, -3:, -1]
- image_slice_depth = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262]
- )
- expected_slice_depth = np.array([103.46727, 85.812004, 87.849236])
-
- assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2
-
- def test_stable_diffusion_prompt_embeds(self):
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_1, depth_slice_1 = output.rgb, output.depth
- rgb_slice_1 = rgb_slice_1[0, -3:, -3:, -1]
- depth_slice_1 = depth_slice_1[0, -3:, -1]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
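- # encode the same prompts manually and pass them as prompt_embeds; the output should match the string-prompt run above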
- text_inputs = ldm3d_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=ldm3d_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = ldm3d_pipe.text_encoder(text_inputs)[0]
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_2, depth_slice_2 = output.rgb, output.depth
- rgb_slice_2 = rgb_slice_2[0, -3:, -3:, -1]
- depth_slice_2 = depth_slice_2[0, -3:, -1]
-
- assert np.abs(rgb_slice_1.flatten() - rgb_slice_2.flatten()).max() < 1e-4
- assert np.abs(depth_slice_1.flatten() - depth_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = ldm3d_pipe(**inputs, negative_prompt=negative_prompt)
-
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1]
- depth_slice = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37044, 0.71811503, 0.7223251, 0.48603675, 0.5638391, 0.6364948, 0.42833704, 0.4901315, 0.47926217]
- )
- expected_slice_depth = np.array([107.84738, 84.62802, 89.962135])
- assert np.abs(rgb_slice.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(depth_slice.flatten() - expected_slice_depth).max() < 1e-2
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionLDM3DPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d_stable_diffusion(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1].flatten()
- depth_slice = rgb[0, -3:, -1].flatten()
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512)
-
- expected_slice_rgb = np.array(
- [0.53805465, 0.56707305, 0.5486515, 0.57012236, 0.5814511, 0.56253487, 0.54843014, 0.55092263, 0.6459706]
- )
- expected_slice_depth = np.array(
- [0.9263781, 0.6678672, 0.5486515, 0.92202145, 0.67831135, 0.56253487, 0.9241694, 0.7551478, 0.6459706]
- )
- assert np.abs(rgb_slice - expected_slice_rgb).max() < 3e-3
- assert np.abs(depth_slice - expected_slice_depth).max() < 3e-3
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 50,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.495586
- expected_rgb_std = 0.33795515
- expected_depth_mean = 112.48518
- expected_depth_std = 98.489746
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
-
- def test_ldm3d_v2(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.4194127
- expected_rgb_std = 0.35375586
- expected_depth_mean = 0.5638502
- expected_depth_std = 0.34686103
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512, 1)
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
diff --git a/tests/pipelines/stable_diffusion_panorama/__init__.py b/tests/pipelines/stable_diffusion_panorama/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
deleted file mode 100644
index 61f91cae2b..0000000000
--- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
+++ /dev/null
@@ -1,444 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- StableDiffusionPanoramaPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- backend_max_memory_allocated,
- backend_reset_max_memory_allocated,
- backend_reset_peak_memory_stats,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class StableDiffusionPanoramaPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionPanoramaPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler()
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- # Setting height and width to None to prevent OOMs on CPU.
- "height": None,
- "width": None,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6186, 0.5374, 0.4915, 0.4135, 0.4114, 0.4563, 0.5128, 0.4977, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_circular_padding_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs, circular_padding=True).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # override to speed the overall test timing up.
- def test_inference_batch_consistent(self):
- super().test_inference_batch_consistent(batch_sizes=[1, 2])
-
- # override to speed the overall test timing up.
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=5.0e-3)
-
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1e-1)
-
- def test_stable_diffusion_panorama_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = sd_pipe(**inputs, negative_prompt=negative_prompt)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch_circular_padding(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_euler(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = EulerAncestralDiscreteScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, seed=0):
- generator = torch.manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default(self):
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- 0.36968392,
- 0.27025372,
- 0.32446766,
- 0.28379387,
- 0.36363274,
- 0.30733347,
- 0.27100027,
- 0.27054125,
- 0.25536096,
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_k_lms(self):
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-base", safety_checker=None
- )
- pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
- pipe.unet.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- [
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- ]
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_intermediate_state(self):
- number_of_steps = 0
-
- def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
- callback_fn.has_been_called = True
- nonlocal number_of_steps
- number_of_steps += 1
- if step == 1:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18681869,
- 0.33907816,
- 0.5361276,
- 0.14432865,
- -0.02856611,
- -0.73941123,
- 0.23397987,
- 0.47322682,
- -0.37823164,
- ]
- )
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
- elif step == 2:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18539645,
- 0.33987248,
- 0.5378559,
- 0.14437142,
- -0.02455261,
- -0.7338317,
- 0.23990755,
- 0.47356272,
- -0.3786505,
- ]
- )
-
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
-
- callback_fn.has_been_called = False
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- pipe(**inputs, callback=callback_fn, callback_steps=1)
- assert callback_fn.has_been_called
- assert number_of_steps == 3
-
- def test_stable_diffusion_panorama_pipeline_with_sequential_cpu_offloading(self):
- backend_empty_cache(torch_device)
- backend_reset_max_memory_allocated(torch_device)
- backend_reset_peak_memory_stats(torch_device)
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
-
- inputs = self.get_inputs()
- _ = pipe(**inputs)
-
- mem_bytes = backend_max_memory_allocated(torch_device)
- # make sure that less than 5.5 GB is allocated
- assert mem_bytes < 5.5 * 10**9
diff --git a/tests/pipelines/stable_diffusion_safe/__init__.py b/tests/pipelines/stable_diffusion_safe/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
deleted file mode 100644
index 5d81cff3e0..0000000000
--- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
+++ /dev/null
@@ -1,497 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
- Expectations,
- backend_empty_cache,
- floats_tensor,
- nightly,
- require_accelerator,
- require_torch_accelerator,
- torch_device,
-)
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_safe_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # assemble the pipeline with the DDIM scheduler configured above
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5125, 0.5716, 0.4828, 0.5060, 0.5650, 0.4768, 0.5185, 0.4895, 0.4993])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_accelerator
- def test_stable_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_accelerator
-class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_harm_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
- " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
- " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
- " children from bahnhof zoo, detailed "
- )
- seed = 4003660346
- guidance_scale = 7
-
- # without safety guidance (sld_guidance_scale = 0)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
- ("cuda", 7): [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176],
- ("cuda", 8): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with safety guidance (strong configuration)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
-
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
- ("cuda", 7): [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719],
- ("cuda", 8): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
- seed = 2734971755
- guidance_scale = 7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.3244, 0.3355, 0.3260, 0.3123, 0.3246, 0.3426, 0.3109, 0.3471, 0.4001],
- ("cuda", 7): [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297],
- ("cuda", 8): [0.3605, 0.3684, 0.3712, 0.3624, 0.3675, 0.3726, 0.3494, 0.3748, 0.4044],
- }
- )
- expected_slice = expected_slices.get_expectation()
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.6178, 0.6260, 0.6194, 0.6435, 0.6265, 0.6461, 0.6567, 0.6576, 0.6444],
- ("cuda", 7): [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443],
- ("cuda", 8): [0.5892, 0.5959, 0.5914, 0.6123, 0.5982, 0.6141, 0.6180, 0.6262, 0.6171],
- }
- )
-
- print(f"image_slice: {image_slice.flatten()}")
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safetychecker_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
- " leyendecker"
- )
- seed = 1044355234
- guidance_scale = 12
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
- ("cuda", 7): np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]),
- ("cuda", 8): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_sag/__init__.py b/tests/pipelines/stable_diffusion_sag/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
deleted file mode 100644
index 1d18403322..0000000000
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DEISMultistepScheduler,
- DPMSolverMultistepScheduler,
- EulerDiscreteScheduler,
- StableDiffusionSAGPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionSAGPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionSAGPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=8,
- norm_num_groups=1,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- norm_num_groups=1,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- num_hidden_layers=2,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": ".",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 1.0,
- "sag_scale": 1.0,
- "output_type": "np",
- }
- return inputs
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
- @unittest.skip("Not necessary to test here.")
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
- def test_pipeline_different_schedulers(self):
- pipeline = self.pipeline_class(**self.get_dummy_components())
- inputs = self.get_dummy_inputs("cpu")
-
- expected_image_size = (16, 16, 3)
- for scheduler_cls in [DDIMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler]:
- pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
- image = pipeline(**inputs).images[0]
-
- shape = image.shape
- assert shape == expected_image_size
-
- pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-
- with self.assertRaises(ValueError):
- # Karras schedulers are not supported
- image = pipeline(**inputs).images[0]
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_1(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1568, 0.1738, 0.1695, 0.1693, 0.1507, 0.1705, 0.1547, 0.1751, 0.1949])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2_non_square(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt],
- width=768,
- height=512,
- generator=generator,
- guidance_scale=7.5,
- sag_scale=1.0,
- num_inference_steps=20,
- output_type="np",
- )
-
- image = output.images
-
- assert image.shape == (1, 512, 768, 3)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
deleted file mode 100644
index ae131d1d4f..0000000000
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionXLKDiffusionPipeline
-from diffusers.utils.testing_utils import (
- Expectations,
- backend_empty_cache,
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
- dtype = torch.float16
-
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_xl(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=9.0,
- num_inference_steps=2,
- height=512,
- width=512,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.5420, 0.5038, 0.2439, 0.5371, 0.4660, 0.1906, 0.5221, 0.4290, 0.2566])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=2,
- output_type="np",
- use_karras_sigmas=True,
- height=512,
- width=512,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slices = Expectations(
- {
- ("xpu", 3): np.array(
- [
- 0.6128,
- 0.6108,
- 0.6109,
- 0.5997,
- 0.5988,
- 0.5948,
- 0.5903,
- 0.597,
- 0.5973,
- ]
- ),
- ("cuda", 7): np.array(
- [
- 0.6418,
- 0.6424,
- 0.6462,
- 0.6271,
- 0.6314,
- 0.6295,
- 0.6249,
- 0.6339,
- 0.6335,
- ]
- ),
- }
- )
-
- expected_slice = expected_slices.get_expectation()
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 69dd79bb56..13c25ccaa4 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -33,6 +33,7 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
@@ -1378,7 +1379,6 @@ class PipelineTesterMixin:
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
-
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1386,17 +1386,20 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
-
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
+
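+ # move outputs to CPU so the NumPy-based cosine-distance helper below can handle accelerator tensors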
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
- assert max_diff < 1e-2
+ assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
@@ -2646,7 +2649,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
- output = run_forward(pipe).flatten().flatten()
+ output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2753,6 +2756,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+ # the threshold is intentionally set higher than usual because these tests use small, randomly initialized models
+ # that do not satisfy the properties of a converged denoiser needed for caching to be effective
+ first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
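+ # NOTE: the host test class is expected to provide `get_dummy_components(num_layers=...)` and
+ # `get_dummy_inputs(device)`; a pipeline test class opts in by listing this mixin alongside
+ # PipelineTesterMixin in its bases.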
+ def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
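+ # run a few denoising steps so that, after the first step, the cache has a first-block output it can potentially reuse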
+ inputs["num_inference_steps"] = 4
+ return pipe(**inputs)[0]
+
+ # Run inference without FirstBlockCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache enabled
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.first_block_cache_config)
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+ "FirstBlockCache outputs should not differ much."
+ )
+ assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
+
+
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
diff --git a/tests/pipelines/text_to_video_synthesis/__init__.py b/tests/pipelines/text_to_video_synthesis/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
deleted file mode 100644
index 445f876985..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoSDPipeline, UNet3DConditionModel
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_numpy,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoSDPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(8, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = TextToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.8093, 0.2751, 0.6976, 0.5927, 0.4616, 0.4336, 0.5094, 0.5683, 0.4796])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skipIf(torch_device != "cuda", reason="Feature isn't heavily used. Test in CUDA environment only.")
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@skip_mps
-@require_torch_accelerator
-class TextToVideoSDPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_two_step_model(self):
- expected_video = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
- )
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4
-
- def test_two_step_model_with_freeu(self):
- expected_video = []
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- video = video_frames[0, 0, -3:, -3:, -1].flatten()
-
- expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]
-
- assert np.abs(expected_video - video).mean() < 5e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
deleted file mode 100644
index 8c29b27416..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import torch
-
-from diffusers import DDIMScheduler, TextToVideoZeroPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- load_pt,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..test_pipelines_common import assert_mean_pixel_difference
-
-
-@nightly
-@require_torch_accelerator
-class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_full_model(self):
- model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(torch_device)
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- prompt = "A bear is playing a guitar on Times Square"
- result = pipe(prompt=prompt, generator=generator).images
-
- expected_result = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt",
- weights_only=False,
- )
-
- assert_mean_pixel_difference(result, expected_result)
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
deleted file mode 100644
index da60435d0d..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import inspect
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_accelerate_version_greater,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoZeroSDXLPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- generator_device = "cpu"
-
- def get_dummy_components(self, seed=0):
- torch.manual_seed(seed)
- unet = UNet2DConditionModel(
- block_out_channels=(2, 4),
- layers_per_block=2,
- sample_size=2,
- norm_num_groups=2,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- )
- scheduler = DDIMScheduler(
- num_train_timesteps=1000,
- beta_start=0.0001,
- beta_end=0.02,
- beta_schedule="linear",
- trained_betas=None,
- clip_sample=True,
- set_alpha_to_one=True,
- steps_offset=0,
- prediction_type="epsilon",
- thresholding=False,
- dynamic_thresholding_ratio=0.995,
- clip_sample_range=1.0,
- sample_max_value=1.0,
- timestep_spacing="leading",
- rescale_betas_zero_snr=False,
- )
- torch.manual_seed(seed)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(seed)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A panda dancing in Antarctica",
- "generator": generator,
- "num_inference_steps": 5,
- "t0": 1,
- "t1": 3,
- "height": 64,
- "width": 64,
- "video_length": 3,
- "output_type": "np",
- }
- return inputs
-
- def get_generator(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- return generator
-
- def test_text_to_video_zero_sdxl(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- result = pipe(**inputs).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array(
- [0.6008109, 0.73051643, 0.51778656, 0.55817354, 0.45222935, 0.45998418, 0.57017255, 0.54874814, 0.47078788]
- )
- expected_slice2 = np.array(
- [0.6011751, 0.47420046, 0.41660714, 0.6472957, 0.41261768, 0.5438129, 0.7401535, 0.6756011, 0.53652245]
- )
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_cfg(self):
- sig = inspect.signature(self.pipeline_class.__call__)
- if "guidance_scale" not in sig.parameters:
- return
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
-
- inputs["guidance_scale"] = 1.0
- out_no_cfg = pipe(**inputs)[0]
-
- inputs["guidance_scale"] = 7.5
- out_cfg = pipe(**inputs)[0]
-
- assert out_cfg.shape == out_no_cfg.shape
-
- def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(self.generator_device))[0]
- output_tuple = pipe(**self.get_dummy_inputs(self.generator_device), return_dict=False)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
- self.assertLess(max_diff, expected_max_difference)
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_torch_accelerator
- def test_float16_inference(self, expected_max_diff=5e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- components = self.get_dummy_components()
- pipe_fp16 = self.pipeline_class(**components)
- pipe_fp16.to(torch_device, torch.float16)
- pipe_fp16.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
-        # Reset generator in case it is used inside dummy inputs
- if "generator" in inputs:
- inputs["generator"] = self.get_generator(self.generator_device)
-
- output = pipe(**inputs)[0]
-
- fp16_inputs = self.get_dummy_inputs(self.generator_device)
- # Reset generator in case it is used inside dummy inputs
- if "generator" in fp16_inputs:
- fp16_inputs["generator"] = self.get_generator(self.generator_device)
-
- output_fp16 = pipe_fp16(**fp16_inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
- self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
-
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_inference_batch_single_identical(self):
- pass
-
- @require_torch_accelerator
- @require_accelerate_version_greater("0.17.0")
- def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_without_offload = pipe(**inputs)[0]
-
- pipe.enable_model_cpu_offload(device=torch_device)
- inputs = self.get_dummy_inputs(self.generator_device)
- output_with_offload = pipe(**inputs)[0]
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_pipeline_call_signature(self):
- pass
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_torch_accelerator
- def test_save_load_float16(self, expected_max_diff=1e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
-
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for name, component in pipe_loaded.components.items():
- if hasattr(component, "dtype"):
- self.assertTrue(
- component.dtype == torch.float16,
- f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
- )
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_loaded = pipe_loaded(**inputs)[0]
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(
- max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
- )
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_local(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_optional_components(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_sequential_cpu_offload_forward_pass(self):
- pass
-
- @require_torch_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
-
-@nightly
-@require_torch_accelerator
-class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_full_model(self):
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
- pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
- model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
- )
- pipe.enable_model_cpu_offload()
- pipe.enable_vae_slicing()
-
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- prompt = "A panda dancing in Antarctica"
- result = pipe(prompt=prompt, generator=generator).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array([0.57, 0.57, 0.57, 0.57, 0.57, 0.56, 0.55, 0.56, 0.56])
- expected_slice2 = np.array([0.54, 0.53, 0.53, 0.53, 0.53, 0.52, 0.53, 0.53, 0.53])
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
deleted file mode 100644
index 2efef3d640..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- UNet3DConditionModel,
- VideoToVideoSDPipeline,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- is_flaky,
- nightly,
- numpy_cosine_similarity_distance,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = VideoToVideoSDPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"video"}) - {"image", "width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"video"}) - {"image"}
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
- test_attention_slicing = False
-
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=32,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=True,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[
- 8,
- ],
- in_channels=3,
- out_channels=3,
- down_block_types=[
- "DownEncoderBlock2D",
- ],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- # 3 frames
- video = floats_tensor((1, 3, 3, 32, 32), rng=random.Random(seed)).to(device)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "video": video,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = VideoToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @is_flaky()
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.001)
-
- @is_flaky()
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent()
-
- @is_flaky()
- def test_save_load_local(self):
- super().test_save_load_local()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=5e-3)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@skip_mps
-class VideoToVideoSDPipelineSlowTests(unittest.TestCase):
- def test_two_step_model(self):
- pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
- # 10 frames
- generator = torch.Generator(device="cpu").manual_seed(0)
- video = torch.randn((1, 10, 3, 320, 576), generator=generator)
-
- prompt = "Spiderman is surfing"
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames
-
- expected_array = np.array(
- [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
- )
- output_array = video_frames[0, 0, :3, :3, 0].flatten()
- assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3
diff --git a/tests/pipelines/unclip/__init__.py b/tests/pipelines/unclip/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
deleted file mode 100644
index 4a970a4f6f..0000000000
--- a/tests/pipelines/unclip/test_unclip.py
+++ /dev/null
@@ -1,523 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- backend_max_memory_allocated,
- backend_reset_max_memory_allocated,
- backend_reset_peak_memory_stats,
- enable_full_determinism,
- load_numpy,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPPipeline
- params = TEXT_TO_IMAGE_PARAMS - {
- "negative_prompt",
- "height",
- "width",
- "negative_prompt_embeds",
- "guidance_scale",
- "prompt_embeds",
- "cross_attention_kwargs",
- }
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- required_optional_params = [
- "generator",
- "return_dict",
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "num_attention_heads": 2,
- "attention_head_dim": 12,
- "embedding_dim": self.text_embedder_hidden_size,
- "num_layers": 1,
- }
-
- model = PriorTransformer(**model_kwargs)
- return model
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
-            # Out channels is double the in channels because the model predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
-        # seeded differently to get a different unet from `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- prior_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="sample",
- num_train_timesteps=1000,
- clip_sample_range=5.0,
- )
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- components = {
- "prior": prior,
- "decoder": decoder,
- "text_proj": text_proj,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "prior_scheduler": prior_scheduler,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_num_inference_steps": 2,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_unclip(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(
- **self.get_dummy_inputs(device),
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9988,
- 0.0028,
- 0.9997,
- 0.9984,
- 0.9965,
- 0.0029,
- 0.9986,
- 0.0025,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_text_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- prior = components["prior"]
- decoder = components["decoder"]
- super_res_first = components["super_res_first"]
- tokenizer = components["tokenizer"]
- text_encoder = components["text_encoder"]
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = prior.dtype
- batch_size = 1
-
- shape = (batch_size, prior.config.embedding_dim)
- prior_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
- shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
- generator = torch.Generator(device=device).manual_seed(0)
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- super_res_first.config.in_channels // 2,
- super_res_first.config.sample_size,
- super_res_first.config.sample_size,
- )
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "this is a prompt example"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = pipe(
- [prompt],
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- output_type="np",
- )
- image = output.images
-
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- return_tensors="pt",
- )
- text_model_output = text_encoder(text_inputs.input_ids)
- text_attention_mask = text_inputs.attention_mask
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_text = pipe(
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- text_model_output=text_model_output,
- text_attention_mask=text_attention_mask,
- output_type="np",
- )[0]
-
- # make sure passing text embeddings manually is identical
- assert np.abs(image - image_from_text).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
-    # because UnCLIP GPU non-determinism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
- self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference, expected_max_diff=0.01)
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
-    # because UnCLIP non-determinism requires a looser check.
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=9.8e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=5e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @unittest.skip("UnCLIP produces very large differences in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_karlo_cpu_fp32(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_cpu.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipeline(
- "horse",
- num_images_per_prompt=1,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
- assert np.abs(expected_image - image).max() < 1e-1
-
-
-@nightly
-@require_torch_accelerator
-class UnCLIPPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_karlo(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_fp16.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- "horse",
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image)
-
- def test_unclip_pipeline_with_sequential_cpu_offloading(self):
- backend_empty_cache(torch_device)
- backend_reset_max_memory_allocated(torch_device)
- backend_reset_peak_memory_stats(torch_device)
-
- pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
- pipe.enable_sequential_cpu_offload()
-
- _ = pipe(
- "horse",
- num_images_per_prompt=1,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- output_type="np",
- )
-
- mem_bytes = backend_max_memory_allocated(torch_device)
- # make sure that less than 7 GB is allocated
- assert mem_bytes < 7 * 10**9
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
deleted file mode 100644
index 15733513a5..0000000000
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ /dev/null
@@ -1,540 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- DiffusionPipeline,
- UnCLIPImageVariationPipeline,
- UnCLIPScheduler,
- UNet2DConditionModel,
- UNet2DModel,
-)
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- load_numpy,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPImageVariationPipeline
- params = IMAGE_VARIATION_PARAMS - {"height", "width", "guidance_scale"}
- batch_params = IMAGE_VARIATION_BATCH_PARAMS
-
- required_optional_params = [
- "generator",
- "return_dict",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
- supports_dduf = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_image_encoder(self):
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- num_hidden_layers=5,
- num_attention_heads=4,
- image_size=32,
- intermediate_size=37,
- patch_size=1,
- )
- return CLIPVisionModelWithProjection(config)
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
-            # Out channels is double the in channels because the model predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
-        # seeded differently to get a different unet from `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- image_encoder = self.dummy_image_encoder
-
- return {
- "decoder": decoder,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_proj": text_proj,
- "feature_extractor": feature_extractor,
- "image_encoder": image_encoder,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- def get_dummy_inputs(self, device, seed=0, pil_image=True):
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- if pil_image:
- input_image = input_image * 0.5 + 0.5
- input_image = input_image.clamp(0, 1)
- input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy()
- input_image = DiffusionPipeline.numpy_to_pil(input_image)[0]
-
- return {
- "image": input_image,
- "generator": generator,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
-
- def test_unclip_image_variation_input_tensor(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.0002,
- 0.9997,
- 0.9997,
- 0.9969,
- 0.0023,
- 0.9997,
- 0.9969,
- 0.9970,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_image(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_list_images(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- pipeline_inputs["image"] = [
- pipeline_inputs["image"],
- pipeline_inputs["image"],
- ]
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- tuple_pipeline_inputs["image"] = [
- tuple_pipeline_inputs["image"],
- tuple_pipeline_inputs["image"],
- ]
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (2, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9989,
- 0.0008,
- 0.0021,
- 0.9960,
- 0.0018,
- 0.0014,
- 0.0002,
- 0.9933,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_image_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = pipe.decoder.dtype
- batch_size = 1
-
- shape = (
- batch_size,
- pipe.decoder.config.in_channels,
- pipe.decoder.config.sample_size,
- pipe.decoder.config.sample_size,
- )
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- pipe.super_res_first.config.in_channels // 2,
- pipe.super_res_first.config.sample_size,
- pipe.super_res_first.config.sample_size,
- )
- generator = torch.Generator(device=device).manual_seed(0)
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- img_out_1 = pipe(
- **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
- ).images
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
- # Don't pass image, instead pass embedding
- image = pipeline_inputs.pop("image")
- image_embeddings = pipe.image_encoder(image).image_embeds
-
- img_out_2 = pipe(
- **pipeline_inputs,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- image_embeddings=image_embeddings,
- ).images
-
-        # make sure passing image embeddings manually is identical
- assert np.abs(img_out_1 - img_out_2).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
-    # because UnCLIP GPU non-determinism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
-        # Check is relaxed because there is no torch 2.0 sliced-attention added-KV processor
- expected_max_diff = 1e-2
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
- )
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
-    # because UnCLIP non-determinism requires a looser check.
- @unittest.skip("UnCLIP produces very large differences. Test is not useful.")
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
-    @unittest.skip("UnCLIP produces very large differences. Test is not useful.")
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=4e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
-    @unittest.skip("UnCLIP produces very large differences in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-@require_torch_accelerator
-class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_image_variation_karlo(self):
- input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png"
- )
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
- )
-
- pipeline = UnCLIPImageVariationPipeline.from_pretrained(
- "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
- )
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- input_image,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image, 15)
diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/unidiffuser/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
deleted file mode 100644
index dccb1a8500..0000000000
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ /dev/null
@@ -1,764 +0,0 @@
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
- GPT2Tokenizer,
-)
-
-from diffusers import (
- AutoencoderKL,
- DPMSolverMultistepScheduler,
- UniDiffuserModel,
- UniDiffuserPipeline,
- UniDiffuserTextDecoder,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class UniDiffuserPipelineFastTests(
- PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
-):
- pipeline_class = UniDiffuserPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
- image_latents_params = frozenset(["vae_latents"])
-
- supports_dduf = False
-
- def get_dummy_components(self):
- unet = UniDiffuserModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="unet",
- )
-
- scheduler = DPMSolverMultistepScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- solver_order=3,
- )
-
- vae = AutoencoderKL.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="vae",
- )
-
- text_encoder = CLIPTextModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_encoder",
- )
- clip_tokenizer = CLIPTokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="clip_tokenizer",
- )
-
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="image_encoder",
- )
- # From the Stable Diffusion Image Variation pipeline tests
- clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
- # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_tokenizer = GPT2Tokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_tokenizer",
- )
- text_decoder = UniDiffuserTextDecoder.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_decoder",
- )
-
- components = {
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "clip_image_processor": clip_image_processor,
- "clip_tokenizer": clip_tokenizer,
- "text_decoder": text_decoder,
- "text_tokenizer": text_tokenizer,
- "unet": unet,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- generator = torch.Generator(device=device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def get_dummy_inputs_with_latents(self, device, seed=0):
- # image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- # image = image.cpu().permute(0, 2, 3, 1)[0]
- # image = Image.fromarray(np.uint8(image)).convert("RGB")
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg",
- )
- image = image.resize((32, 32))
- latents = self.get_fixed_latents(device, seed=seed)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "prompt_latents": latents.get("prompt_latents"),
- "vae_latents": latents.get("vae_latents"),
- "clip_latents": latents.get("clip_latents"),
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.7489, 0.3722, 0.4475, 0.5630, 0.5923, 0.4992, 0.3936, 0.5844, 0.4975])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_unidiffuser_default_joint_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_no_cfg_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- # Set guidance scale to 1.0 to turn off CFG
- inputs["guidance_scale"] = 1.0
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_image_0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img'
- unidiffuser_pipe.set_image_mode()
- assert unidiffuser_pipe.mode == "img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") image generation.
- del inputs["prompt"]
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text'
- unidiffuser_pipe.set_text_mode()
- assert unidiffuser_pipe.mode == "text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") text generation.
- del inputs["prompt"]
- del inputs["image"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_img2text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_img2text_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_text2img_multiple_images(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_unidiffuser_text2img_multiple_images_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=2e-4)
-
- @require_torch_accelerator
- def test_unidiffuser_default_joint_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @require_torch_accelerator
- def test_unidiffuser_default_text2img_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete image for text-conditioned image generation.
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- @require_torch_accelerator
- def test_unidiffuser_default_img2text_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt for image-conditioned text generation.
- del inputs["prompt"]
- inputs["data_type"] = 1
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @unittest.skip(
- "Test not supported because it has a bunch of direct configs at init and also, this pipeline isn't used that much now."
- )
- def test_encode_prompt_works_in_isolation():
- pass
-
-
-@nightly
-@require_torch_accelerator
-class UniDiffuserPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
-
-@nightly
-@require_torch_accelerator
-class UniDiffuserPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 2e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
diff --git a/tests/pipelines/wuerstchen/__init__.py b/tests/pipelines/wuerstchen/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
deleted file mode 100644
index 060a11434e..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenCombinedPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "prior_guidance_scale",
- "decoder_guidance_scale",
- "negative_prompt",
- "num_inference_steps",
- "return_dict",
- "prior_num_inference_steps",
- "output_type",
- ]
- test_xformers_attention = True
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2}
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_prior_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- prior_text_encoder = self.dummy_prior_text_encoder
-
- scheduler = DDPMWuerstchenScheduler()
- tokenizer = self.dummy_tokenizer
-
- text_encoder = self.dummy_text_encoder
- decoder = self.dummy_decoder
- vqgan = self.dummy_vqgan
-
- components = {
- "tokenizer": tokenizer,
- "text_encoder": text_encoder,
- "decoder": decoder,
- "vqgan": vqgan,
- "scheduler": scheduler,
- "prior_prior": prior,
- "prior_text_encoder": prior_text_encoder,
- "prior_tokenizer": tokenizer,
- "prior_scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_guidance_scale": 4.0,
- "decoder_guidance_scale": 4.0,
- "num_inference_steps": 2,
- "prior_num_inference_steps": 2,
- "output_type": "np",
- "height": 128,
- "width": 128,
- }
- return inputs
-
- def test_wuerstchen(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
-
- assert image.shape == (1, 128, 128, 3)
-
- expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
-
- @require_torch_accelerator
- def test_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=1e-2)
-
- @unittest.skip(reason="flakey and float16 requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_inputs(self):
- pass
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_cfg(self):
- pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
deleted file mode 100644
index 5d2462d48d..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenDecoderPipeline
- params = ["prompt"]
- batch_params = ["image_embeddings", "prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- vqgan = self.dummy_vqgan
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "decoder": decoder,
- "vqgan": vqgan,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "latent_dim_scale": 4.0,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image_embeddings": torch.ones((1, 4, 4, 4), device=device),
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 1.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_decoder(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=1e-5)
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="bf16 not supported and requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip("Test not supported.")
- def test_encode_prompt_works_in_isolation(self):
- super().test_encode_prompt_works_in_isolation()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
deleted file mode 100644
index 34f7c684b7..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
-
-
-if is_peft_available():
- from peft import LoraConfig
- from peft.tuners.tuners_utils import BaseTunerLayer
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenPriorPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "generator",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_in": 2,
- "c": 8,
- "depth": 2,
- "c_cond": 32,
- "c_r": 8,
- "nhead": 2,
- }
-
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "prior": prior,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 4.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_prior(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.image_embeddings
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, 0, 0, -10:]
- image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
- assert image.shape == (1, 2, 24, 24)
-
- expected_slice = np.array(
- [
- -7172.837,
- -3438.855,
- -1093.312,
- 388.8835,
- -7471.467,
- -7998.1206,
- -5328.259,
- 218.00089,
- -2731.5745,
- -8056.734,
- ]
- )
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(
- expected_max_diff=3e-1,
- )
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="flaky for now")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- # override because we need to make sure latent_mean and latent_std are 0
- def test_callback_inputs(self):
- components = self.get_dummy_components()
- components["latent_mean"] = 0
- components["latent_std"] = 0
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
- def check_if_lora_correctly_set(self, model) -> bool:
- """
- Checks if the LoRA layers are correctly set with peft
- """
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- return True
- return False
-
- def get_lora_components(self):
- prior = self.dummy_prior
-
- prior_lora_config = LoraConfig(
- r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
- )
-
- return prior, prior_lora_config
-
- @require_peft_backend
- def test_inference_with_prior_lora(self):
- _, prior_lora_config = self.get_lora_components()
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output_no_lora = pipe(**self.get_dummy_inputs(device))
- image_embed = output_no_lora.image_embeddings
- self.assertTrue(image_embed.shape == (1, 2, 24, 24))
-
- pipe.prior.add_adapter(prior_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
-
- output_lora = pipe(**self.get_dummy_inputs(device))
- lora_image_embed = output_lora.image_embeddings
-
- self.assertTrue(image_embed.shape == lora_image_embed.shape)
-
- @unittest.skip("Test not supported as dtype cannot be inferred without the text encoder otherwise.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index c5497d1c8d..98005cfbc8 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
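+ # Record whether deterministic algorithms were already enabled so tearDownClass can restore the original setting.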
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -865,6 +872,7 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
+@require_bitsandbytes_version_greater("0.45.5")
class Bnb4BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 383cdd6849..f3bbc34e8b 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -830,6 +837,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
+@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests):
@property
def quantization_config(self):
diff --git a/tests/pipelines/audioldm/__init__.py b/tests/quantization/gguf/__init__.py
similarity index 100%
rename from tests/pipelines/audioldm/__init__.py
rename to tests/quantization/gguf/__init__.py
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index 5d1fa4c22e..fe56f890ee 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -8,6 +8,7 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
+ DiffusionPipeline,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
@@ -15,6 +16,8 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanTransformer3DModel,
+ WanVACETransformer3DModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -30,9 +33,12 @@ from diffusers.utils.testing_utils import (
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_peft_backend,
+ require_torch_version_greater,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
+
if is_gguf_available():
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -577,3 +583,99 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_image": torch.randn(
+ (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanVACETransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests):
+ torch_dtype = torch.bfloat16
+ gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+
+ @property
+ def quantization_config(self):
+ return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+ def test_torch_compile(self):
+ super()._test_torch_compile(quantization_config=self.quantization_config)
+
+ def test_torch_compile_with_cpu_offload(self):
+ super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
+
+ def test_torch_compile_with_group_offload_leaf(self):
+ super()._test_torch_compile_with_group_offload_leaf(quantization_config=self.quantization_config)
+
+ def _init_pipeline(self, *args, **kwargs):
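+ # Load the GGUF-quantized Flux transformer from the single-file checkpoint, then plug it into the FLUX.1-dev pipeline.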
+ transformer = FluxTransformer2DModel.from_single_file(
+ self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ return pipe
diff --git a/utils/print_env.py b/utils/print_env.py
index 2d2acb59d5..2fe0777daf 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -28,6 +28,16 @@ print("Python version:", sys.version)
print("OS platform:", platform.platform())
print("OS architecture:", platform.machine())
+try:
+ import psutil
+
+ vm = psutil.virtual_memory()
+ total_gb = vm.total / (1024**3)
+ available_gb = vm.available / (1024**3)
+ print(f"Total RAM: {total_gb:.2f} GB")
+ print(f"Available RAM: {available_gb:.2f} GB")
+except ImportError:
+ pass
try:
import torch