diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 696097fd54..747e1d8154 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -11,17 +11,18 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
+ BASE_PATH: benchmark_outputs
jobs:
- torch_pipelines_cuda_benchmark_tests:
+ torch_models_cuda_benchmark_tests:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
- name: Torch Core Pipelines CUDA Benchmarking Tests
+ name: Torch Core Models CUDA Benchmarking Tests
strategy:
fail-fast: false
max-parallel: 1
runs-on:
- group: aws-g6-4xlarge-plus
+ group: aws-g6e-4xlarge
container:
image: diffusers/diffusers-pytorch-cuda
options: --shm-size "16gb" --ipc host --gpus 0
@@ -35,27 +36,47 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
+ apt update
+ apt install -y libpq-dev postgresql-client
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
python -m uv pip install -e [quality,test]
- python -m uv pip install pandas peft
- python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
+ python -m uv pip install -r benchmarks/requirements.txt
- name: Environment
run: |
python utils/print_env.py
- name: Diffusers Benchmarking
env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
- BASE_PATH: benchmark_outputs
+ HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
run: |
- export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
- cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
+ cd benchmarks && python run_all.py
+
+ - name: Push results to the Hub
+ env:
+ HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
+ run: |
+ cd benchmarks && python push_results.py
+ mkdir $BASE_PATH && cp *.csv $BASE_PATH
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
- path: benchmarks/benchmark_outputs
+ path: benchmarks/${{ env.BASE_PATH }}
+
+ # TODO: enable this once the connection problem has been resolved.
+ - name: Update benchmarking results to DB
+ env:
+ PGDATABASE: metrics
+ PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
+ PGUSER: transformers_benchmarks
+ PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
+ BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
+ run: |
+ git config --global --add safe.directory /__w/diffusers/diffusers
+ commit_id=$GITHUB_SHA
+ commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
+ cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}
diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml
index 838f241ddc..583853c6d6 100644
--- a/.github/workflows/build_docker_images.yml
+++ b/.github/workflows/build_docker_images.yml
@@ -75,10 +75,6 @@ jobs:
- diffusers-pytorch-cuda
- diffusers-pytorch-xformers-cuda
- diffusers-pytorch-minimum-cuda
- - diffusers-flax-cpu
- - diffusers-flax-tpu
- - diffusers-onnxruntime-cpu
- - diffusers-onnxruntime-cuda
- diffusers-doc-builder
steps:
diff --git a/.github/workflows/mirror_community_pipeline.yml b/.github/workflows/mirror_community_pipeline.yml
index f6eff1bbd8..9cf573312b 100644
--- a/.github/workflows/mirror_community_pipeline.yml
+++ b/.github/workflows/mirror_community_pipeline.yml
@@ -79,14 +79,14 @@ jobs:
# Check secret is set
- name: whoami
- run: huggingface-cli whoami
+ run: hf auth whoami
env:
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
# Push to HF! (under subfolder based on checkout ref)
# https://huggingface.co/datasets/diffusers/community-pipelines-mirror
- name: Mirror community pipeline to HF
- run: huggingface-cli upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
+ run: hf upload diffusers/community-pipelines-mirror ./examples/community ${PATH_IN_REPO} --repo-type dataset
env:
PATH_IN_REPO: ${{ env.PATH_IN_REPO }}
HF_TOKEN: ${{ secrets.HF_TOKEN_MIRROR_COMMUNITY_PIPELINES }}
diff --git a/.github/workflows/nightly_tests.yml b/.github/workflows/nightly_tests.yml
index 5476616704..384f07506a 100644
--- a/.github/workflows/nightly_tests.yml
+++ b/.github/workflows/nightly_tests.yml
@@ -248,7 +248,7 @@ jobs:
BIG_GPU_MEMORY: 40
run: |
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -m "big_gpu_with_torch_cuda" \
+ -m "big_accelerator" \
--make-reports=tests_big_gpu_torch_cuda \
--report-log=tests_big_gpu_torch_cuda.log \
tests/
@@ -321,55 +321,6 @@ jobs:
name: torch_minimum_version_cuda_test_reports
path: reports
- run_nightly_onnx_tests:
- name: Nightly ONNXRuntime CUDA tests on Ubuntu
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --gpus 0 --shm-size "16gb" --ipc host
-
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: NVIDIA-SMI
- run: nvidia-smi
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
- - name: Environment
- run: python utils/print_env.py
-
- - name: Run Nightly ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- --report-log=tests_onnx_cuda.log \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: tests_onnx_cuda_reports
- path: reports
-
run_nightly_quantization_tests:
name: Torch quantization nightly tests
strategy:
@@ -485,57 +436,6 @@ jobs:
name: torch_cuda_pipeline_level_quant_reports
path: reports
- run_flax_tpu_tests:
- name: Nightly Flax TPU Tests
- runs-on:
- group: gcp-ct5lp-hightpu-8t
- if: github.event_name == 'schedule'
-
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
- python -m uv pip install pytest-reportlog
-
- - name: Environment
- run: python utils/print_env.py
-
- - name: Run nightly Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- --report-log=tests_flax_tpu.log \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
generate_consolidated_report:
name: Generate Consolidated Test Report
needs: [
@@ -545,9 +445,9 @@ jobs:
run_big_gpu_torch_tests,
run_nightly_quantization_tests,
run_nightly_pipeline_level_quantization_tests,
- run_nightly_onnx_tests,
+ # run_nightly_onnx_tests,
torch_minimum_version_cuda_tests,
- run_flax_tpu_tests
+ # run_flax_tpu_tests
]
if: always()
runs-on:
diff --git a/.github/workflows/pr_tests.yml b/.github/workflows/pr_tests.yml
index a0bf1e79e8..34a344528e 100644
--- a/.github/workflows/pr_tests.yml
+++ b/.github/workflows/pr_tests.yml
@@ -87,11 +87,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_models_schedulers
- - name: Fast Flax CPU tests
- framework: flax
- runner: aws-general-8-plus
- image: diffusers/diffusers-flax-cpu
- report: flax_cpu
- name: PyTorch Example CPU tests
framework: pytorch_examples
runner: aws-general-8-plus
@@ -147,15 +142,6 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/models tests/schedulers tests/others
- - name: Run fast Flax TPU tests
- if: ${{ matrix.config.framework == 'flax' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Flax" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests
-
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
diff --git a/.github/workflows/pr_tests_gpu.yml b/.github/workflows/pr_tests_gpu.yml
index 87d5177388..bb74daad21 100644
--- a/.github/workflows/pr_tests_gpu.yml
+++ b/.github/workflows/pr_tests_gpu.yml
@@ -13,6 +13,7 @@ on:
- "src/diffusers/loaders/peft.py"
- "tests/pipelines/test_pipelines_common.py"
- "tests/models/test_modeling_common.py"
+ - "examples/**/*.py"
workflow_dispatch:
concurrency:
@@ -188,7 +189,7 @@ jobs:
shell: bash
strategy:
fail-fast: false
- max-parallel: 2
+ max-parallel: 4
matrix:
module: [models, schedulers, lora, others]
steps:
diff --git a/.github/workflows/push_tests.yml b/.github/workflows/push_tests.yml
index 7cab08b44f..007770c8ed 100644
--- a/.github/workflows/push_tests.yml
+++ b/.github/workflows/push_tests.yml
@@ -159,102 +159,6 @@ jobs:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
- flax_tpu_tests:
- name: Flax TPU Tests
- runs-on:
- group: gcp-ct5lp-hightpu-8t
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
- onnx_cuda_tests:
- name: ONNX CUDA Tests
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: onnx_cuda_test_reports
- path: reports
-
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
diff --git a/.github/workflows/push_tests_fast.yml b/.github/workflows/push_tests_fast.yml
index e8a73446de..e274cb0218 100644
--- a/.github/workflows/push_tests_fast.yml
+++ b/.github/workflows/push_tests_fast.yml
@@ -33,16 +33,6 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu
- - name: Fast Flax CPU tests on Ubuntu
- framework: flax
- runner: aws-general-8-plus
- image: diffusers/diffusers-flax-cpu
- report: flax_cpu
- - name: Fast ONNXRuntime CPU tests on Ubuntu
- framework: onnxruntime
- runner: aws-general-8-plus
- image: diffusers/diffusers-onnxruntime-cpu
- report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: aws-general-8-plus
@@ -87,24 +77,6 @@ jobs:
--make-reports=tests_${{ matrix.config.report }} \
tests/
- - name: Run fast Flax TPU tests
- if: ${{ matrix.config.framework == 'flax' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Flax" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests/
-
- - name: Run fast ONNXRuntime CPU tests
- if: ${{ matrix.config.framework == 'onnxruntime' }}
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_${{ matrix.config.report }} \
- tests/
-
- name: Run example PyTorch CPU tests
if: ${{ matrix.config.framework == 'pytorch_examples' }}
run: |
diff --git a/.github/workflows/push_tests_mps.yml b/.github/workflows/push_tests_mps.yml
index 5fd3b78be7..eb6c0da225 100644
--- a/.github/workflows/push_tests_mps.yml
+++ b/.github/workflows/push_tests_mps.yml
@@ -1,12 +1,7 @@
name: Fast mps tests on main
on:
- push:
- branches:
- - main
- paths:
- - "src/diffusers/**.py"
- - "tests/**.py"
+ workflow_dispatch:
env:
DIFFUSERS_IS_CI: yes
diff --git a/.github/workflows/release_tests_fast.yml b/.github/workflows/release_tests_fast.yml
index a464381ba4..e5d3282049 100644
--- a/.github/workflows/release_tests_fast.yml
+++ b/.github/workflows/release_tests_fast.yml
@@ -213,101 +213,6 @@ jobs:
with:
name: torch_minimum_version_cuda_test_reports
path: reports
-
- flax_tpu_tests:
- name: Flax TPU Tests
- runs-on: docker-tpu
- container:
- image: diffusers/diffusers-flax-tpu
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run slow Flax TPU tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 0 \
- -s -v -k "Flax" \
- --make-reports=tests_flax_tpu \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_flax_tpu_stats.txt
- cat reports/tests_flax_tpu_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: flax_tpu_test_reports
- path: reports
-
- onnx_cuda_tests:
- name: ONNX CUDA Tests
- runs-on:
- group: aws-g4dn-2xlarge
- container:
- image: diffusers/diffusers-onnxruntime-cuda
- options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
- defaults:
- run:
- shell: bash
- steps:
- - name: Checkout diffusers
- uses: actions/checkout@v3
- with:
- fetch-depth: 2
-
- - name: Install dependencies
- run: |
- python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
- python -m uv pip install -e [quality,test]
- pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
-
- - name: Environment
- run: |
- python utils/print_env.py
-
- - name: Run slow ONNXRuntime CUDA tests
- env:
- HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
- run: |
- python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
- -s -v -k "Onnx" \
- --make-reports=tests_onnx_cuda \
- tests/
-
- - name: Failure short reports
- if: ${{ failure() }}
- run: |
- cat reports/tests_onnx_cuda_stats.txt
- cat reports/tests_onnx_cuda_failures_short.txt
-
- - name: Test suite reports artifacts
- if: ${{ always() }}
- uses: actions/upload-artifact@v4
- with:
- name: onnx_cuda_test_reports
- path: reports
run_torch_compile_tests:
name: PyTorch Compile CUDA tests
diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 0000000000..afab1b0de3
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,69 @@
+# Diffusers Benchmarks
+
+Welcome to Diffusers Benchmarks. These benchmarks are use to obtain latency and memory information of the most popular models across different scenarios such as:
+
+* Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
+* Base + `torch.compile()`
+* NF4 quantization
+* Layerwise upcasting
+
+Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
+
+The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.
+
+The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
+
+## Running the benchmarks manually
+
+First set up `torch` and install `diffusers` from the root of the directory:
+
+```py
+pip install -e ".[quality,test]"
+```
+
+Then make sure the other dependencies are installed:
+
+```sh
+cd benchmarks/
+pip install -r requirements.txt
+```
+
+We need to be authenticated to access some of the checkpoints used during benchmarking:
+
+```sh
+hf auth login
+```
+
+We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).
+
+Then you can either launch the entire benchmarking suite by running:
+
+```sh
+python run_all.py
+```
+
+Or, you can run the individual benchmarks.
+
+## Customizing the benchmarks
+
+We define "scenarios" to cover the most common ways in which these models are used. You can
+define a new scenario, modifying an existing benchmark file:
+
+```py
+BenchmarkScenario(
+ name=f"{CKPT_ID}-bnb-8bit",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+)
+```
+
+You can also configure a new model-level benchmark and add it to the existing suite. To do so, just defining a valid benchmarking file like `benchmarking_flux.py` should be enough.
+
+Happy benchmarking 🧨
\ No newline at end of file
diff --git a/tests/pipelines/amused/__init__.py b/benchmarks/__init__.py
similarity index 100%
rename from tests/pipelines/amused/__init__.py
rename to benchmarks/__init__.py
diff --git a/benchmarks/base_classes.py b/benchmarks/base_classes.py
deleted file mode 100644
index 45bf65c93c..0000000000
--- a/benchmarks/base_classes.py
+++ /dev/null
@@ -1,346 +0,0 @@
-import os
-import sys
-
-import torch
-
-from diffusers import (
- AutoPipelineForImage2Image,
- AutoPipelineForInpainting,
- AutoPipelineForText2Image,
- ControlNetModel,
- LCMScheduler,
- StableDiffusionAdapterPipeline,
- StableDiffusionControlNetPipeline,
- StableDiffusionXLAdapterPipeline,
- StableDiffusionXLControlNetPipeline,
- T2IAdapter,
- WuerstchenCombinedPipeline,
-)
-from diffusers.utils import load_image
-
-
-sys.path.append(".")
-
-from utils import ( # noqa: E402
- BASE_PATH,
- PROMPT,
- BenchmarkInfo,
- benchmark_fn,
- bytes_to_giga_bytes,
- flush,
- generate_csv_dict,
- write_to_csv,
-)
-
-
-RESOLUTION_MAPPING = {
- "Lykon/DreamShaper": (512, 512),
- "lllyasviel/sd-controlnet-canny": (512, 512),
- "diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
- "TencentARC/t2iadapter_canny_sd14v1": (512, 512),
- "TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
- "stabilityai/stable-diffusion-2-1": (768, 768),
- "stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
- "stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
- "stabilityai/sdxl-turbo": (512, 512),
-}
-
-
-class BaseBenchmak:
- pipeline_class = None
-
- def __init__(self, args):
- super().__init__()
-
- def run_inference(self, args):
- raise NotImplementedError
-
- def benchmark(self, args):
- raise NotImplementedError
-
- def get_result_filepath(self, args):
- pipeline_class_name = str(self.pipe.__class__.__name__)
- name = (
- args.ckpt.replace("/", "_")
- + "_"
- + pipeline_class_name
- + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
- )
- filepath = os.path.join(BASE_PATH, name)
- return filepath
-
-
-class TextToImageBenchmark(BaseBenchmak):
- pipeline_class = AutoPipelineForText2Image
-
- def __init__(self, args):
- pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- if args.run_compile:
- if not isinstance(pipe, WuerstchenCombinedPipeline):
- pipe.unet.to(memory_format=torch.channels_last)
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
- pipe.movq.to(memory_format=torch.channels_last)
- pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
- else:
- print("Run torch compile")
- pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
- pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
- def benchmark(self, args):
- flush()
-
- print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
-
- time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
- memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
- benchmark_info = BenchmarkInfo(time=time, memory=memory)
-
- pipeline_class_name = str(self.pipe.__class__.__name__)
- flush()
- csv_dict = generate_csv_dict(
- pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
- )
- filepath = self.get_result_filepath(args)
- write_to_csv(filepath, csv_dict)
- print(f"Logs written to: {filepath}")
- flush()
-
-
-class TurboTextToImageBenchmark(TextToImageBenchmark):
- def __init__(self, args):
- super().__init__(args)
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=0.0,
- )
-
-
-class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
- lora_id = "latent-consistency/lcm-lora-sdxl"
-
- def __init__(self, args):
- super().__init__(args)
- self.pipe.load_lora_weights(self.lora_id)
- self.pipe.fuse_lora()
- self.pipe.unload_lora_weights()
- self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
-
- def get_result_filepath(self, args):
- pipeline_class_name = str(self.pipe.__class__.__name__)
- name = (
- self.lora_id.replace("/", "_")
- + "_"
- + pipeline_class_name
- + f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
- )
- filepath = os.path.join(BASE_PATH, name)
- return filepath
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=1.0,
- )
-
- def benchmark(self, args):
- flush()
-
- print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
-
- time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
- memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
- benchmark_info = BenchmarkInfo(time=time, memory=memory)
-
- pipeline_class_name = str(self.pipe.__class__.__name__)
- flush()
- csv_dict = generate_csv_dict(
- pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
- )
- filepath = self.get_result_filepath(args)
- write_to_csv(filepath, csv_dict)
- print(f"Logs written to: {filepath}")
- flush()
-
-
-class ImageToImageBenchmark(TextToImageBenchmark):
- pipeline_class = AutoPipelineForImage2Image
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
- image = load_image(url).convert("RGB")
-
- def __init__(self, args):
- super().__init__(args)
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class TurboImageToImageBenchmark(ImageToImageBenchmark):
- def __init__(self, args):
- super().__init__(args)
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- guidance_scale=0.0,
- strength=0.5,
- )
-
-
-class InpaintingBenchmark(ImageToImageBenchmark):
- pipeline_class = AutoPipelineForInpainting
- mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
- mask = load_image(mask_url).convert("RGB")
-
- def __init__(self, args):
- super().__init__(args)
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
- self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- mask_image=self.mask,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
- url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
- image = load_image(url)
-
- def __init__(self, args):
- pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
- pipe.load_ip_adapter(
- args.ip_adapter_id[0],
- subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
- weight_name=args.ip_adapter_id[1],
- )
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- ip_adapter_image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class ControlNetBenchmark(TextToImageBenchmark):
- pipeline_class = StableDiffusionControlNetPipeline
- aux_network_class = ControlNetModel
- root_ckpt = "Lykon/DreamShaper"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
- image = load_image(url).convert("RGB")
-
- def __init__(self, args):
- aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.controlnet.to(memory_format=torch.channels_last)
-
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
-
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
- def run_inference(self, pipe, args):
- _ = pipe(
- prompt=PROMPT,
- image=self.image,
- num_inference_steps=args.num_inference_steps,
- num_images_per_prompt=args.batch_size,
- )
-
-
-class ControlNetSDXLBenchmark(ControlNetBenchmark):
- pipeline_class = StableDiffusionXLControlNetPipeline
- root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
-
- def __init__(self, args):
- super().__init__(args)
-
-
-class T2IAdapterBenchmark(ControlNetBenchmark):
- pipeline_class = StableDiffusionAdapterPipeline
- aux_network_class = T2IAdapter
- root_ckpt = "Lykon/DreamShaper"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
- image = load_image(url).convert("L")
-
- def __init__(self, args):
- aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
- pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
- pipe = pipe.to("cuda")
-
- pipe.set_progress_bar_config(disable=True)
- self.pipe = pipe
-
- if args.run_compile:
- pipe.unet.to(memory_format=torch.channels_last)
- pipe.adapter.to(memory_format=torch.channels_last)
-
- print("Run torch compile")
- pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
- pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
-
- self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
-
-
-class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
- pipeline_class = StableDiffusionXLAdapterPipeline
- root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
-
- url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
- image = load_image(url)
-
- def __init__(self, args):
- super().__init__(args)
diff --git a/benchmarks/benchmark_controlnet.py b/benchmarks/benchmark_controlnet.py
deleted file mode 100644
index 9217004461..0000000000
--- a/benchmarks/benchmark_controlnet.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="lllyasviel/sd-controlnet-canny",
- choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = (
- ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
- )
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_ip_adapters.py b/benchmarks/benchmark_ip_adapters.py
deleted file mode 100644
index 9a31a21fc6..0000000000
--- a/benchmarks/benchmark_ip_adapters.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
-
-
-IP_ADAPTER_CKPTS = {
- # because original SD v1.5 has been taken down.
- "Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
- "stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
-}
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="rstabilityai/stable-diffusion-xl-base-1.0",
- choices=list(IP_ADAPTER_CKPTS.keys()),
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
- benchmark_pipe = IPAdapterTextToImageBenchmark(args)
- args.ckpt = f"{args.ckpt} (IP-Adapter)"
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_sd_img.py b/benchmarks/benchmark_sd_img.py
deleted file mode 100644
index 772befe879..0000000000
--- a/benchmarks/benchmark_sd_img.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=[
- "Lykon/DreamShaper",
- "stabilityai/stable-diffusion-2-1",
- "stabilityai/stable-diffusion-xl-refiner-1.0",
- "stabilityai/sdxl-turbo",
- ],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_sd_inpainting.py b/benchmarks/benchmark_sd_inpainting.py
deleted file mode 100644
index 143adcb0d8..0000000000
--- a/benchmarks/benchmark_sd_inpainting.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import InpaintingBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=[
- "Lykon/DreamShaper",
- "stabilityai/stable-diffusion-2-1",
- "stabilityai/stable-diffusion-xl-base-1.0",
- ],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = InpaintingBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_t2i_adapter.py b/benchmarks/benchmark_t2i_adapter.py
deleted file mode 100644
index 44b04b470e..0000000000
--- a/benchmarks/benchmark_t2i_adapter.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="TencentARC/t2iadapter_canny_sd14v1",
- choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = (
- T2IAdapterBenchmark(args)
- if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
- else T2IAdapterSDXLBenchmark(args)
- )
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_t2i_lcm_lora.py b/benchmarks/benchmark_t2i_lcm_lora.py
deleted file mode 100644
index 957e0a463e..0000000000
--- a/benchmarks/benchmark_t2i_lcm_lora.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="stabilityai/stable-diffusion-xl-base-1.0",
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=4)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_pipe = LCMLoRATextToImageBenchmark(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmark_text_to_image.py b/benchmarks/benchmark_text_to_image.py
deleted file mode 100644
index ddc7fb2676..0000000000
--- a/benchmarks/benchmark_text_to_image.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import argparse
-import sys
-
-
-sys.path.append(".")
-from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
-
-
-ALL_T2I_CKPTS = [
- "Lykon/DreamShaper",
- "segmind/SSD-1B",
- "stabilityai/stable-diffusion-xl-base-1.0",
- "kandinsky-community/kandinsky-2-2-decoder",
- "warp-ai/wuerstchen",
- "stabilityai/sdxl-turbo",
-]
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--ckpt",
- type=str,
- default="Lykon/DreamShaper",
- choices=ALL_T2I_CKPTS,
- )
- parser.add_argument("--batch_size", type=int, default=1)
- parser.add_argument("--num_inference_steps", type=int, default=50)
- parser.add_argument("--model_cpu_offload", action="store_true")
- parser.add_argument("--run_compile", action="store_true")
- args = parser.parse_args()
-
- benchmark_cls = None
- if "turbo" in args.ckpt:
- benchmark_cls = TurboTextToImageBenchmark
- else:
- benchmark_cls = TextToImageBenchmark
-
- benchmark_pipe = benchmark_cls(args)
- benchmark_pipe.benchmark(args)
diff --git a/benchmarks/benchmarking_flux.py b/benchmarks/benchmarking_flux.py
new file mode 100644
index 0000000000..18a2680052
--- /dev/null
+++ b/benchmarks/benchmarking_flux.py
@@ -0,0 +1,98 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+RESULT_FILENAME = "flux.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # resolution: 1024x1024
+ # maximum sequence length 512
+ hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+ pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
+ image_ids = torch.ones(512, 3, **device_dtype_kwargs)
+ text_ids = torch.ones(4096, 3, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ guidance = torch.tensor([1.0], **device_dtype_kwargs)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "img_ids": image_ids,
+ "txt_ids": text_ids,
+ "pooled_projections": pooled_prompt_embeds,
+ "timestep": timestep,
+ "guidance": guidance,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bnb-nf4",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ "quantization_config": BitsAndBytesConfig(
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
+ ),
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=FluxTransformer2DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_ltx.py b/benchmarks/benchmarking_ltx.py
new file mode 100644
index 0000000000..3d698fd0bd
--- /dev/null
+++ b/benchmarks/benchmarking_ltx.py
@@ -0,0 +1,80 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import LTXVideoTransformer3DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
+RESULT_FILENAME = "ltx.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # 512x704 (161 frames)
+ # `max_sequence_length`: 256
+ hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
+ encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "encoder_attention_mask": encoder_attention_mask,
+ "timestep": timestep,
+ "video_coords": video_coords,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=LTXVideoTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_sdxl.py b/benchmarks/benchmarking_sdxl.py
new file mode 100644
index 0000000000..ded62784f2
--- /dev/null
+++ b/benchmarks/benchmarking_sdxl.py
@@ -0,0 +1,82 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import UNet2DConditionModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
+RESULT_FILENAME = "sdxl.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # height: 1024
+ # width: 1024
+ # max_sequence_length: 77
+ hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+ added_cond_kwargs = {
+ "text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
+ "time_ids": torch.ones(1, 6, **device_dtype_kwargs),
+ }
+
+ return {
+ "sample": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "added_cond_kwargs": added_cond_kwargs,
+ }
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=UNet2DConditionModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "unet",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/benchmarking_utils.py b/benchmarks/benchmarking_utils.py
new file mode 100644
index 0000000000..c8c1a10ef8
--- /dev/null
+++ b/benchmarks/benchmarking_utils.py
@@ -0,0 +1,244 @@
+import gc
+import inspect
+import logging
+import os
+import queue
+import threading
+from contextlib import nullcontext
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Optional, Union
+
+import pandas as pd
+import torch
+import torch.utils.benchmark as benchmark
+
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device
+
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+NUM_WARMUP_ROUNDS = 5
+
+
+def benchmark_fn(f, *args, **kwargs):
+ t0 = benchmark.Timer(
+ stmt="f(*args, **kwargs)",
+ globals={"args": args, "kwargs": kwargs, "f": f},
+ num_threads=1,
+ )
+ return float(f"{(t0.blocked_autorange().mean):.3f}")
+
+
+def flush():
+ gc.collect()
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ torch.cuda.reset_peak_memory_stats()
+
+
+# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
+def calculate_flops(model, input_dict):
+ try:
+ from torchprofile import profile_macs
+ except ModuleNotFoundError:
+ raise
+
+ # This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
+ sig = inspect.signature(model.forward)
+ param_names = [
+ p.name
+ for p in sig.parameters.values()
+ if p.kind
+ in (
+ inspect.Parameter.POSITIONAL_ONLY,
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ )
+ and p.name != "self"
+ ]
+ bound = sig.bind_partial(**input_dict)
+ bound.apply_defaults()
+ args = tuple(bound.arguments[name] for name in param_names)
+
+ model.eval()
+ with torch.no_grad():
+ macs = profile_macs(model, args)
+ flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
+ return flops
+
+
+def calculate_params(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+# Users can define their own in case this doesn't suffice. For most cases,
+# it should be sufficient.
+def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
+ model = model_cls.from_pretrained(**init_kwargs).eval()
+ if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
+ model.enable_group_offload(**group_offload_kwargs)
+ else:
+ model.to(torch_device)
+ if layerwise_upcasting:
+ model.enable_layerwise_casting(
+ storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
+ )
+ return model
+
+
+@dataclass
+class BenchmarkScenario:
+ name: str
+ model_cls: ModelMixin
+ model_init_kwargs: Dict[str, Any]
+ model_init_fn: Callable
+ get_model_input_dict: Callable
+ compile_kwargs: Optional[Dict[str, Any]] = None
+
+
+@require_torch_gpu
+class BenchmarkMixin:
+ def pre_benchmark(self):
+ flush()
+ torch.compiler.reset()
+
+ def post_benchmark(self, model):
+ model.cpu()
+ flush()
+ torch.compiler.reset()
+
+ @torch.no_grad()
+ def run_benchmark(self, scenario: BenchmarkScenario):
+ # 0) Basic stats
+ logger.info(f"Running scenario: {scenario.name}.")
+ try:
+ model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
+ num_params = round(calculate_params(model) / 1e9, 2)
+ try:
+ flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
+ except Exception as e:
+ logger.info(f"Problem in calculating FLOPs:\n{e}")
+ flops = None
+ model.cpu()
+ del model
+ except Exception as e:
+ logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
+ return {}
+ self.pre_benchmark()
+
+ # 1) plain stats
+ results = {}
+ plain = None
+ try:
+ plain = self._run_phase(
+ model_cls=scenario.model_cls,
+ init_fn=scenario.model_init_fn,
+ init_kwargs=scenario.model_init_kwargs,
+ get_input_fn=scenario.get_model_input_dict,
+ compile_kwargs=None,
+ )
+ except Exception as e:
+ logger.info(f"Benchmark could not be run with the following error:\n{e}")
+ return results
+
+ # 2) compiled stats (if any)
+ compiled = {"time": None, "memory": None}
+ if scenario.compile_kwargs:
+ try:
+ compiled = self._run_phase(
+ model_cls=scenario.model_cls,
+ init_fn=scenario.model_init_fn,
+ init_kwargs=scenario.model_init_kwargs,
+ get_input_fn=scenario.get_model_input_dict,
+ compile_kwargs=scenario.compile_kwargs,
+ )
+ except Exception as e:
+ logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
+ if plain is None:
+ return results
+
+ # 3) merge
+ result = {
+ "scenario": scenario.name,
+ "model_cls": scenario.model_cls.__name__,
+ "num_params_B": num_params,
+ "flops_G": flops,
+ "time_plain_s": plain["time"],
+ "mem_plain_GB": plain["memory"],
+ "time_compile_s": compiled["time"],
+ "mem_compile_GB": compiled["memory"],
+ }
+ if scenario.compile_kwargs:
+ result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
+ result["mode"] = scenario.compile_kwargs.get("mode", "default")
+ else:
+ result["fullgraph"], result["mode"] = None, None
+ return result
+
+ def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
+ if not isinstance(scenarios, list):
+ scenarios = [scenarios]
+ record_queue = queue.Queue()
+ stop_signal = object()
+
+ def _writer_thread():
+ while True:
+ item = record_queue.get()
+ if item is stop_signal:
+ break
+ df_row = pd.DataFrame([item])
+ write_header = not os.path.exists(filename)
+ df_row.to_csv(filename, mode="a", header=write_header, index=False)
+ record_queue.task_done()
+
+ record_queue.task_done()
+
+ writer = threading.Thread(target=_writer_thread, daemon=True)
+ writer.start()
+
+ for s in scenarios:
+ try:
+ record = self.run_benchmark(s)
+ if record:
+ record_queue.put(record)
+ else:
+ logger.info(f"Record empty from scenario: {s.name}.")
+ except Exception as e:
+ logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
+ record_queue.put(stop_signal)
+ logger.info(f"Results serialized to {filename=}.")
+
+ def _run_phase(
+ self,
+ *,
+ model_cls: ModelMixin,
+ init_fn: Callable,
+ init_kwargs: Dict[str, Any],
+ get_input_fn: Callable,
+ compile_kwargs: Optional[Dict[str, Any]],
+ ) -> Dict[str, float]:
+ # setup
+ self.pre_benchmark()
+
+ # init & (optional) compile
+ model = init_fn(model_cls, **init_kwargs)
+ if compile_kwargs:
+ model.compile(**compile_kwargs)
+
+ # build inputs
+ inp = get_input_fn()
+
+ # measure
+ run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
+ with run_ctx:
+ for _ in range(NUM_WARMUP_ROUNDS):
+ _ = model(**inp)
+ time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
+ mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
+ mem_gb = round(mem_gb, 2)
+
+ # teardown
+ self.post_benchmark(model)
+ del model
+ return {"time": time_s, "memory": mem_gb}
diff --git a/benchmarks/benchmarking_wan.py b/benchmarks/benchmarking_wan.py
new file mode 100644
index 0000000000..64e81fdb6b
--- /dev/null
+++ b/benchmarks/benchmarking_wan.py
@@ -0,0 +1,74 @@
+from functools import partial
+
+import torch
+from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
+
+from diffusers import WanTransformer3DModel
+from diffusers.utils.testing_utils import torch_device
+
+
+CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+RESULT_FILENAME = "wan.csv"
+
+
+def get_input_dict(**device_dtype_kwargs):
+ # height: 480
+ # width: 832
+ # num_frames: 81
+ # max_sequence_length: 512
+ hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
+ encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
+ timestep = torch.tensor([1.0], **device_dtype_kwargs)
+
+ return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
+
+
+if __name__ == "__main__":
+ scenarios = [
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-bf16",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=model_init_fn,
+ compile_kwargs={"fullgraph": True},
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-layerwise-upcasting",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
+ ),
+ BenchmarkScenario(
+ name=f"{CKPT_ID}-group-offload-leaf",
+ model_cls=WanTransformer3DModel,
+ model_init_kwargs={
+ "pretrained_model_name_or_path": CKPT_ID,
+ "torch_dtype": torch.bfloat16,
+ "subfolder": "transformer",
+ },
+ get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
+ model_init_fn=partial(
+ model_init_fn,
+ group_offload_kwargs={
+ "onload_device": torch_device,
+ "offload_device": torch.device("cpu"),
+ "offload_type": "leaf_level",
+ "use_stream": True,
+ "non_blocking": True,
+ },
+ ),
+ ),
+ ]
+
+ runner = BenchmarkMixin()
+ runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
diff --git a/benchmarks/populate_into_db.py b/benchmarks/populate_into_db.py
new file mode 100644
index 0000000000..55e46b0586
--- /dev/null
+++ b/benchmarks/populate_into_db.py
@@ -0,0 +1,166 @@
+import argparse
+import os
+import sys
+
+import gpustat
+import pandas as pd
+import psycopg2
+import psycopg2.extras
+from psycopg2.extensions import register_adapter
+from psycopg2.extras import Json
+
+
+register_adapter(dict, Json)
+
+FINAL_CSV_FILENAME = "collated_results.csv"
+# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
+BENCHMARKS_TABLE_NAME = "benchmarks"
+MEASUREMENTS_TABLE_NAME = "model_measurements"
+
+
+def _init_benchmark(conn, branch, commit_id, commit_msg):
+ gpu_stats = gpustat.GPUStatCollection.new_query()
+ metadata = {"gpu_name": gpu_stats[0]["name"]}
+ repository = "huggingface/diffusers"
+ with conn.cursor() as cur:
+ cur.execute(
+ f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
+ (repository, branch, commit_id, commit_msg, metadata),
+ )
+ benchmark_id = cur.fetchone()[0]
+ print(f"Initialised benchmark #{benchmark_id}")
+ return benchmark_id
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "branch",
+ type=str,
+ help="The branch name on which the benchmarking is performed.",
+ )
+
+ parser.add_argument(
+ "commit_id",
+ type=str,
+ help="The commit hash on which the benchmarking is performed.",
+ )
+
+ parser.add_argument(
+ "commit_msg",
+ type=str,
+ help="The commit message associated with the commit, truncated to 70 characters.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ try:
+ conn = psycopg2.connect(
+ host=os.getenv("PGHOST"),
+ database=os.getenv("PGDATABASE"),
+ user=os.getenv("PGUSER"),
+ password=os.getenv("PGPASSWORD"),
+ )
+ print("DB connection established successfully.")
+ except Exception as e:
+ print(f"Problem during DB init: {e}")
+ sys.exit(1)
+
+ try:
+ benchmark_id = _init_benchmark(
+ conn=conn,
+ branch=args.branch,
+ commit_id=args.commit_id,
+ commit_msg=args.commit_msg,
+ )
+ except Exception as e:
+ print(f"Problem during initializing benchmark: {e}")
+ sys.exit(1)
+
+ cur = conn.cursor()
+
+ df = pd.read_csv(FINAL_CSV_FILENAME)
+
+ # Helper to cast values (or None) given a dtype
+ def _cast_value(val, dtype: str):
+ if pd.isna(val):
+ return None
+
+ if dtype == "text":
+ return str(val).strip()
+
+ if dtype == "float":
+ try:
+ return float(val)
+ except ValueError:
+ return None
+
+ if dtype == "bool":
+ s = str(val).strip().lower()
+ if s in ("true", "t", "yes", "1"):
+ return True
+ if s in ("false", "f", "no", "0"):
+ return False
+ if val in (1, 1.0):
+ return True
+ if val in (0, 0.0):
+ return False
+ return None
+
+ return val
+
+ try:
+ rows_to_insert = []
+ for _, row in df.iterrows():
+ scenario = _cast_value(row.get("scenario"), "text")
+ model_cls = _cast_value(row.get("model_cls"), "text")
+ num_params_B = _cast_value(row.get("num_params_B"), "float")
+ flops_G = _cast_value(row.get("flops_G"), "float")
+ time_plain_s = _cast_value(row.get("time_plain_s"), "float")
+ mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
+ time_compile_s = _cast_value(row.get("time_compile_s"), "float")
+ mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
+ fullgraph = _cast_value(row.get("fullgraph"), "bool")
+ mode = _cast_value(row.get("mode"), "text")
+
+ # If "github_sha" column exists in the CSV, cast it; else default to None
+ if "github_sha" in df.columns:
+ github_sha = _cast_value(row.get("github_sha"), "text")
+ else:
+ github_sha = None
+
+ measurements = {
+ "scenario": scenario,
+ "model_cls": model_cls,
+ "num_params_B": num_params_B,
+ "flops_G": flops_G,
+ "time_plain_s": time_plain_s,
+ "mem_plain_GB": mem_plain_GB,
+ "time_compile_s": time_compile_s,
+ "mem_compile_GB": mem_compile_GB,
+ "fullgraph": fullgraph,
+ "mode": mode,
+ "github_sha": github_sha,
+ }
+ rows_to_insert.append((benchmark_id, measurements))
+
+ # Batch-insert all rows
+ insert_sql = f"""
+ INSERT INTO {MEASUREMENTS_TABLE_NAME} (
+ benchmark_id,
+ measurements
+ )
+ VALUES (%s, %s);
+ """
+
+ psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
+ conn.commit()
+
+ cur.close()
+ conn.close()
+ except Exception as e:
+ print(f"Exception: {e}")
+ sys.exit(1)
diff --git a/benchmarks/push_results.py b/benchmarks/push_results.py
index 71cd60f32c..8be3b39368 100644
--- a/benchmarks/push_results.py
+++ b/benchmarks/push_results.py
@@ -1,19 +1,19 @@
-import glob
-import sys
+import os
import pandas as pd
from huggingface_hub import hf_hub_download, upload_file
from huggingface_hub.utils import EntryNotFoundError
-sys.path.append(".")
-from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
+REPO_ID = "diffusers/benchmarks"
def has_previous_benchmark() -> str:
+ from run_all import FINAL_CSV_FILENAME
+
csv_path = None
try:
- csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
+ csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
except EntryNotFoundError:
csv_path = None
return csv_path
@@ -26,46 +26,50 @@ def filter_float(value):
def push_to_hf_dataset():
- all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
- collate_csv(all_csvs, FINAL_CSV_FILE)
+ from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
- # If there's an existing benchmark file, we should report the changes.
csv_path = has_previous_benchmark()
if csv_path is not None:
- current_results = pd.read_csv(FINAL_CSV_FILE)
+ current_results = pd.read_csv(FINAL_CSV_FILENAME)
previous_results = pd.read_csv(csv_path)
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
- numeric_columns = [
- c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
- ]
for column in numeric_columns:
- previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
+ # get previous values as floats, aligned to current index
+ prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
- # Calculate the percentage change
- current_results[column] = current_results[column].astype(float)
- previous_results[column] = previous_results[column].astype(float)
- percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
+ # get current values as floats
+ curr_vals = current_results[column].astype(float)
- # Format the values with '+' or '-' sign and append to original values
- current_results[column] = current_results[column].map(str) + percent_change.map(
- lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
+ # stringify the current values
+ curr_str = curr_vals.map(str)
+
+ # build an appendage only when prev exists and differs
+ append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
+ lambda x: f" ({x})" if pd.notnull(x) else ""
)
- # There might be newly added rows. So, filter out the NaNs.
- current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
- # Overwrite the current result file.
- current_results.to_csv(FINAL_CSV_FILE, index=False)
+ # combine
+ current_results[column] = curr_str + append_str
+ os.remove(FINAL_CSV_FILENAME)
+ current_results.to_csv(FINAL_CSV_FILENAME, index=False)
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
upload_file(
repo_id=REPO_ID,
- path_in_repo=FINAL_CSV_FILE,
- path_or_fileobj=FINAL_CSV_FILE,
+ path_in_repo=FINAL_CSV_FILENAME,
+ path_or_fileobj=FINAL_CSV_FILENAME,
repo_type="dataset",
commit_message=commit_message,
)
+ upload_file(
+ repo_id="diffusers/benchmark-analyzer",
+ path_in_repo=FINAL_CSV_FILENAME,
+ path_or_fileobj=FINAL_CSV_FILENAME,
+ repo_type="space",
+ commit_message=commit_message,
+ )
if __name__ == "__main__":
diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt
new file mode 100644
index 0000000000..1f47ecc6ca
--- /dev/null
+++ b/benchmarks/requirements.txt
@@ -0,0 +1,6 @@
+pandas
+psutil
+gpustat
+torchprofile
+bitsandbytes
+psycopg2==2.9.9
\ No newline at end of file
diff --git a/benchmarks/run_all.py b/benchmarks/run_all.py
index c9932cc71c..9cf053f548 100644
--- a/benchmarks/run_all.py
+++ b/benchmarks/run_all.py
@@ -1,101 +1,84 @@
import glob
+import logging
+import os
import subprocess
-import sys
-from typing import List
+
+import pandas as pd
-sys.path.append(".")
-from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
-
-PATTERN = "benchmark_*.py"
+PATTERN = "benchmarking_*.py"
+FINAL_CSV_FILENAME = "collated_results.csv"
+GITHUB_SHA = os.getenv("GITHUB_SHA", None)
class SubprocessCallException(Exception):
pass
-# Taken from `test_examples_utils.py`
-def run_command(command: List[str], return_stdout=False):
- """
- Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
- if an error occurred while running `command`
- """
+def run_command(command: list[str], return_stdout=False):
try:
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
- if return_stdout:
- if hasattr(output, "decode"):
- output = output.decode("utf-8")
- return output
+ if return_stdout and hasattr(output, "decode"):
+ return output.decode("utf-8")
except subprocess.CalledProcessError as e:
- raise SubprocessCallException(
- f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
- ) from e
+ raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
-def main():
- python_files = glob.glob(PATTERN)
+def merge_csvs(final_csv: str = "collated_results.csv"):
+ all_csvs = glob.glob("*.csv")
+ all_csvs = [f for f in all_csvs if f != final_csv]
+ if not all_csvs:
+ logger.info("No result CSVs found to merge.")
+ return
- for file in python_files:
- print(f"****** Running file: {file} ******")
-
- # Run with canonical settings.
- if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
- command = f"python {file}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
-
- # Run variants.
- for file in python_files:
- # See: https://github.com/pytorch/pytorch/issues/129637
- if file == "benchmark_ip_adapters.py":
+ df_list = []
+ for f in all_csvs:
+ try:
+ d = pd.read_csv(f)
+ except pd.errors.EmptyDataError:
+ # If a file existed but was zero‐bytes or corrupted, skip it
continue
+ df_list.append(d)
- if file == "benchmark_text_to_image.py":
- for ckpt in ALL_T2I_CKPTS:
- command = f"python {file} --ckpt {ckpt}"
+ if not df_list:
+ logger.info("All result CSVs were empty or invalid; nothing to merge.")
+ return
- if "turbo" in ckpt:
- command += " --num_inference_steps 1"
+ final_df = pd.concat(df_list, ignore_index=True)
+ if GITHUB_SHA is not None:
+ final_df["github_sha"] = GITHUB_SHA
+ final_df.to_csv(final_csv, index=False)
+ logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
- run_command(command.split())
- command += " --run_compile"
- run_command(command.split())
+def run_scripts():
+ python_files = sorted(glob.glob(PATTERN))
+ python_files = [f for f in python_files if f != "benchmarking_utils.py"]
- elif file == "benchmark_sd_img.py":
- for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
- command = f"python {file} --ckpt {ckpt}"
+ for file in python_files:
+ script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
+ logger.info(f"\n****** Running file: {file} ******")
- if ckpt == "stabilityai/sdxl-turbo":
- command += " --num_inference_steps 2"
+ partial_csv = f"{script_name}.csv"
+ if os.path.exists(partial_csv):
+ logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.")
+ os.remove(partial_csv)
- run_command(command.split())
- command += " --run_compile"
- run_command(command.split())
+ command = ["python", file]
+ try:
+ run_command(command)
+ logger.info(f"→ {file} finished normally.")
+ except SubprocessCallException as e:
+ logger.info(f"Error running {file}:\n{e}")
+ finally:
+ logger.info(f"→ Merging partial CSVs after {file} …")
+ merge_csvs(final_csv=FINAL_CSV_FILENAME)
- elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
- sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
- command = f"python {file} --ckpt {sdxl_ckpt}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
-
- elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
- sdxl_ckpt = (
- "diffusers/controlnet-canny-sdxl-1.0"
- if "controlnet" in file
- else "TencentARC/t2i-adapter-canny-sdxl-1.0"
- )
- command = f"python {file} --ckpt {sdxl_ckpt}"
- run_command(command.split())
-
- command += " --run_compile"
- run_command(command.split())
+ logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
if __name__ == "__main__":
- main()
+ run_scripts()
diff --git a/benchmarks/utils.py b/benchmarks/utils.py
deleted file mode 100644
index 5fce920ac6..0000000000
--- a/benchmarks/utils.py
+++ /dev/null
@@ -1,98 +0,0 @@
-import argparse
-import csv
-import gc
-import os
-from dataclasses import dataclass
-from typing import Dict, List, Union
-
-import torch
-import torch.utils.benchmark as benchmark
-
-
-GITHUB_SHA = os.getenv("GITHUB_SHA", None)
-BENCHMARK_FIELDS = [
- "pipeline_cls",
- "ckpt_id",
- "batch_size",
- "num_inference_steps",
- "model_cpu_offload",
- "run_compile",
- "time (secs)",
- "memory (gbs)",
- "actual_gpu_memory (gbs)",
- "github_sha",
-]
-
-PROMPT = "ghibli style, a fantasy landscape with castles"
-BASE_PATH = os.getenv("BASE_PATH", ".")
-TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
-
-REPO_ID = "diffusers/benchmarks"
-FINAL_CSV_FILE = "collated_results.csv"
-
-
-@dataclass
-class BenchmarkInfo:
- time: float
- memory: float
-
-
-def flush():
- """Wipes off memory."""
- gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.reset_max_memory_allocated()
- torch.cuda.reset_peak_memory_stats()
-
-
-def bytes_to_giga_bytes(bytes):
- return f"{(bytes / 1024 / 1024 / 1024):.3f}"
-
-
-def benchmark_fn(f, *args, **kwargs):
- t0 = benchmark.Timer(
- stmt="f(*args, **kwargs)",
- globals={"args": args, "kwargs": kwargs, "f": f},
- num_threads=torch.get_num_threads(),
- )
- return f"{(t0.blocked_autorange().mean):.3f}"
-
-
-def generate_csv_dict(
- pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
-) -> Dict[str, Union[str, bool, float]]:
- """Packs benchmarking data into a dictionary for latter serialization."""
- data_dict = {
- "pipeline_cls": pipeline_cls,
- "ckpt_id": ckpt,
- "batch_size": args.batch_size,
- "num_inference_steps": args.num_inference_steps,
- "model_cpu_offload": args.model_cpu_offload,
- "run_compile": args.run_compile,
- "time (secs)": benchmark_info.time,
- "memory (gbs)": benchmark_info.memory,
- "actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
- "github_sha": GITHUB_SHA,
- }
- return data_dict
-
-
-def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
- """Serializes a dictionary into a CSV file."""
- with open(file_name, mode="w", newline="") as csvfile:
- writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
- writer.writeheader()
- writer.writerow(data_dict)
-
-
-def collate_csv(input_files: List[str], output_file: str):
- """Collates multiple identically structured CSVs into a single CSV file."""
- with open(output_file, mode="w", newline="") as outfile:
- writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
- writer.writeheader()
-
- for file in input_files:
- with open(file, mode="r") as infile:
- reader = csv.DictReader(infile)
- for row in reader:
- writer.writerow(row)
diff --git a/docker/diffusers-doc-builder/Dockerfile b/docker/diffusers-doc-builder/Dockerfile
index c9fc62707c..3a76b3331c 100644
--- a/docker/diffusers-doc-builder/Dockerfile
+++ b/docker/diffusers-doc-builder/Dockerfile
@@ -47,6 +47,10 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
tensorboard \
transformers \
matplotlib \
- setuptools==69.5.1
+ setuptools==69.5.1 \
+ bitsandbytes \
+ torchao \
+ gguf \
+ optimum-quanto
CMD ["/bin/bash"]
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 283efeef72..b959831111 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -1,36 +1,39 @@
-- sections:
+- title: Get started
+ sections:
- local: index
- title: 🧨 Diffusers
+ title: Diffusers
+ - local: installation
+ title: Installation
- local: quicktour
title: Quicktour
- local: stable_diffusion
title: Effective and efficient diffusion
- - local: installation
- title: Installation
- title: Get started
-- sections:
- - local: tutorials/tutorial_overview
- title: Overview
- - local: using-diffusers/write_own_pipeline
- title: Understanding pipelines, models and schedulers
- - local: tutorials/autopipeline
- title: AutoPipeline
- - local: tutorials/basic_training
- title: Train a diffusion model
- title: Tutorials
-- sections:
+
+- title: DiffusionPipeline
+ isExpanded: false
+ sections:
- local: using-diffusers/loading
title: Load pipelines
+ - local: tutorials/autopipeline
+ title: AutoPipeline
- local: using-diffusers/custom_pipeline_overview
title: Load community pipelines and components
+ - local: using-diffusers/callback
+ title: Pipeline callbacks
+ - local: using-diffusers/reusing_seeds
+ title: Reproducible pipelines
- local: using-diffusers/schedulers
title: Load schedulers and models
+ - local: using-diffusers/scheduler_features
+ title: Scheduler features
- local: using-diffusers/other-formats
title: Model files and layouts
- local: using-diffusers/push_to_hub
title: Push files to the Hub
- title: Load pipelines and adapters
-- sections:
+
+- title: Adapters
+ isExpanded: false
+ sections:
- local: tutorials/using_peft_for_inference
title: LoRA
- local: using-diffusers/ip_adapter
@@ -43,27 +46,16 @@
title: DreamBooth
- local: using-diffusers/textual_inversion_inference
title: Textual inversion
- title: Adapters
+
+- title: Inference
isExpanded: false
-- sections:
- - local: using-diffusers/unconditional_image_generation
- title: Unconditional image generation
- - local: using-diffusers/conditional_image_generation
- title: Text-to-image
- - local: using-diffusers/img2img
- title: Image-to-image
- - local: using-diffusers/inpaint
- title: Inpainting
- - local: using-diffusers/text-img2vid
- title: Video generation
- - local: using-diffusers/depth2img
- title: Depth-to-image
- title: Generative tasks
-- sections:
- - local: using-diffusers/overview_techniques
- title: Overview
+ sections:
+ - local: using-diffusers/weighted_prompts
+ title: Prompt techniques
- local: using-diffusers/create_a_server
title: Create a server
+ - local: using-diffusers/batched_inference
+ title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- local: using-diffusers/scheduler_features
@@ -74,14 +66,38 @@
title: Reproducible pipelines
- local: using-diffusers/image_quality
title: Controlling image quality
- - local: using-diffusers/weighted_prompts
- title: Prompt techniques
- title: Inference techniques
-- sections:
- - local: advanced_inference/outpaint
- title: Outpainting
- title: Advanced inference
-- sections:
+
+- title: Inference optimization
+ isExpanded: false
+ sections:
+ - local: optimization/fp16
+ title: Accelerate inference
+ - local: optimization/cache
+ title: Caching
+ - local: optimization/memory
+ title: Reduce memory usage
+ - local: optimization/speed-memory-optims
+ title: Compile and offloading quantized models
+ - title: Community optimizations
+ sections:
+ - local: optimization/pruna
+ title: Pruna
+ - local: optimization/xformers
+ title: xFormers
+ - local: optimization/tome
+ title: Token merging
+ - local: optimization/deepcache
+ title: DeepCache
+ - local: optimization/tgate
+ title: TGATE
+ - local: optimization/xdit
+ title: xDiT
+ - local: optimization/para_attn
+ title: ParaAttention
+
+- title: Hybrid Inference
+ isExpanded: false
+ sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
@@ -90,8 +106,110 @@
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
- title: Hybrid Inference
-- sections:
+
+- title: Modular Diffusers
+ isExpanded: false
+ sections:
+ - local: modular_diffusers/overview
+ title: Overview
+ - local: modular_diffusers/modular_pipeline
+ title: Modular Pipeline
+ - local: modular_diffusers/components_manager
+ title: Components Manager
+ - local: modular_diffusers/modular_diffusers_states
+ title: Modular Diffusers States
+ - local: modular_diffusers/pipeline_block
+ title: Pipeline Block
+ - local: modular_diffusers/sequential_pipeline_blocks
+ title: Sequential Pipeline Blocks
+ - local: modular_diffusers/loop_sequential_pipeline_blocks
+ title: Loop Sequential Pipeline Blocks
+ - local: modular_diffusers/auto_pipeline_blocks
+ title: Auto Pipeline Blocks
+ - local: modular_diffusers/end_to_end_guide
+ title: End-to-End Example
+
+- title: Training
+ isExpanded: false
+ sections:
+ - local: training/overview
+ title: Overview
+ - local: training/create_dataset
+ title: Create a dataset for training
+ - local: training/adapt_a_model
+ title: Adapt a model to a new task
+ - local: tutorials/basic_training
+ title: Train a diffusion model
+ - title: Models
+ sections:
+ - local: training/unconditional_training
+ title: Unconditional image generation
+ - local: training/text2image
+ title: Text-to-image
+ - local: training/sdxl
+ title: Stable Diffusion XL
+ - local: training/kandinsky
+ title: Kandinsky 2.2
+ - local: training/wuerstchen
+ title: Wuerstchen
+ - local: training/controlnet
+ title: ControlNet
+ - local: training/t2i_adapters
+ title: T2I-Adapters
+ - local: training/instructpix2pix
+ title: InstructPix2Pix
+ - local: training/cogvideox
+ title: CogVideoX
+ - title: Methods
+ sections:
+ - local: training/text_inversion
+ title: Textual Inversion
+ - local: training/dreambooth
+ title: DreamBooth
+ - local: training/lora
+ title: LoRA
+ - local: training/custom_diffusion
+ title: Custom Diffusion
+ - local: training/lcm_distill
+ title: Latent Consistency Distillation
+ - local: training/ddpo
+ title: Reinforcement learning training with DDPO
+
+- title: Quantization
+ isExpanded: false
+ sections:
+ - local: quantization/overview
+ title: Getting started
+ - local: quantization/bitsandbytes
+ title: bitsandbytes
+ - local: quantization/gguf
+ title: gguf
+ - local: quantization/torchao
+ title: torchao
+ - local: quantization/quanto
+ title: quanto
+
+- title: Model accelerators and hardware
+ isExpanded: false
+ sections:
+ - local: using-diffusers/stable_diffusion_jax_how_to
+ title: JAX/Flax
+ - local: optimization/onnx
+ title: ONNX
+ - local: optimization/open_vino
+ title: OpenVINO
+ - local: optimization/coreml
+ title: Core ML
+ - local: optimization/mps
+ title: Metal Performance Shaders (MPS)
+ - local: optimization/habana
+ title: Intel Gaudi
+ - local: optimization/neuron
+ title: AWS Neuron
+
+- title: Specific pipeline examples
+ isExpanded: false
+ sections:
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -116,106 +234,30 @@
title: Stable Video Diffusion
- local: using-diffusers/marigold_usage
title: Marigold Computer Vision
- title: Specific pipeline examples
-- sections:
- - local: training/overview
- title: Overview
- - local: training/create_dataset
- title: Create a dataset for training
- - local: training/adapt_a_model
- title: Adapt a model to a new task
- - isExpanded: false
+
+- title: Resources
+ isExpanded: false
+ sections:
+ - title: Task recipes
sections:
- - local: training/unconditional_training
+ - local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- - local: training/text2image
+ - local: using-diffusers/conditional_image_generation
title: Text-to-image
- - local: training/sdxl
- title: Stable Diffusion XL
- - local: training/kandinsky
- title: Kandinsky 2.2
- - local: training/wuerstchen
- title: Wuerstchen
- - local: training/controlnet
- title: ControlNet
- - local: training/t2i_adapters
- title: T2I-Adapters
- - local: training/instructpix2pix
- title: InstructPix2Pix
- - local: training/cogvideox
- title: CogVideoX
- title: Models
- - isExpanded: false
- sections:
- - local: training/text_inversion
- title: Textual Inversion
- - local: training/dreambooth
- title: DreamBooth
- - local: training/lora
- title: LoRA
- - local: training/custom_diffusion
- title: Custom Diffusion
- - local: training/lcm_distill
- title: Latent Consistency Distillation
- - local: training/ddpo
- title: Reinforcement learning training with DDPO
- title: Methods
- title: Training
-- sections:
- - local: quantization/overview
- title: Getting Started
- - local: quantization/bitsandbytes
- title: bitsandbytes
- - local: quantization/gguf
- title: gguf
- - local: quantization/torchao
- title: torchao
- - local: quantization/quanto
- title: quanto
- title: Quantization Methods
-- sections:
- - local: optimization/fp16
- title: Accelerate inference
- - local: optimization/cache
- title: Caching
- - local: optimization/memory
- title: Reduce memory usage
- - local: optimization/speed-memory-optims
- title: Compile and offloading quantized models
- - local: optimization/pruna
- title: Pruna
- - local: optimization/xformers
- title: xFormers
- - local: optimization/tome
- title: Token merging
- - local: optimization/deepcache
- title: DeepCache
- - local: optimization/tgate
- title: TGATE
- - local: optimization/xdit
- title: xDiT
- - local: optimization/para_attn
- title: ParaAttention
- - sections:
- - local: using-diffusers/stable_diffusion_jax_how_to
- title: JAX/Flax
- - local: optimization/onnx
- title: ONNX
- - local: optimization/open_vino
- title: OpenVINO
- - local: optimization/coreml
- title: Core ML
- title: Optimized model formats
- - sections:
- - local: optimization/mps
- title: Metal Performance Shaders (MPS)
- - local: optimization/habana
- title: Intel Gaudi
- - local: optimization/neuron
- title: AWS Neuron
- title: Optimized hardware
- title: Accelerate inference and reduce memory
-- sections:
+ - local: using-diffusers/img2img
+ title: Image-to-image
+ - local: using-diffusers/inpaint
+ title: Inpainting
+ - local: advanced_inference/outpaint
+ title: Outpainting
+ - local: using-diffusers/text-img2vid
+ title: Video generation
+ - local: using-diffusers/depth2img
+ title: Depth-to-image
+ - local: using-diffusers/write_own_pipeline
+ title: Understanding pipelines, models and schedulers
+ - local: community_projects
+ title: Projects built with Diffusers
- local: conceptual/philosophy
title: Philosophy
- local: using-diffusers/controlling_generation
@@ -226,13 +268,11 @@
title: Diffusers' Ethical Guidelines
- local: conceptual/evaluation
title: Evaluating Diffusion Models
- title: Conceptual Guides
-- sections:
- - local: community_projects
- title: Projects built with Diffusers
- title: Community Projects
-- sections:
- - isExpanded: false
+
+- title: API
+ isExpanded: false
+ sections:
+ - title: Main Classes
sections:
- local: api/configuration
title: Configuration
@@ -242,8 +282,7 @@
title: Outputs
- local: api/quantization
title: Quantization
- title: Main Classes
- - isExpanded: false
+ - title: Loaders
sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
@@ -259,14 +298,14 @@
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
- title: Loaders
- - isExpanded: false
+ - title: Models
sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- - sections:
+ - title: ControlNets
+ sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
@@ -281,8 +320,8 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- title: ControlNets
- - sections:
+ - title: Transformers
+ sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
@@ -331,6 +370,8 @@
title: SanaTransformer2DModel
- local: api/models/sd3_transformer2d
title: SD3Transformer2DModel
+ - local: api/models/skyreels_v2_transformer_3d
+ title: SkyReelsV2Transformer3DModel
- local: api/models/stable_audio_transformer
title: StableAudioDiTModel
- local: api/models/transformer2d
@@ -339,8 +380,8 @@
title: TransformerTemporalModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- title: Transformers
- - sections:
+ - title: UNets
+ sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
@@ -355,8 +396,8 @@
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
- title: UNets
- - sections:
+ - title: VAEs
+ sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -387,9 +428,7 @@
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
- title: VAEs
- title: Models
- - isExpanded: false
+ - title: Pipelines
sections:
- local: api/pipelines/overview
title: Overview
@@ -525,11 +564,14 @@
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
+ - local: api/pipelines/skyreels_v2
+ title: SkyReels-V2
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- - sections:
+ - title: Stable Diffusion
+ sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
@@ -566,7 +608,6 @@
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
- title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/text_to_video
@@ -585,8 +626,7 @@
title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
- title: Pipelines
- - isExpanded: false
+ - title: Schedulers
sections:
- local: api/schedulers/overview
title: Overview
@@ -656,8 +696,7 @@
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- title: Schedulers
- - isExpanded: false
+ - title: Internal classes
sections:
- local: api/internal_classes_overview
title: Overview
@@ -675,5 +714,3 @@
title: VAE Image Processor
- local: api/video_processor
title: Video Processor
- title: Internal classes
- title: API
diff --git a/docs/source/en/api/cache.md b/docs/source/en/api/cache.md
index e90cb32c54..9ba4742085 100644
--- a/docs/source/en/api/cache.md
+++ b/docs/source/en/api/cache.md
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] FasterCacheConfig
[[autodoc]] apply_faster_cache
+
+### FirstBlockCacheConfig
+
+[[autodoc]] FirstBlockCacheConfig
+
+[[autodoc]] apply_first_block_cache
diff --git a/docs/source/en/api/configuration.md b/docs/source/en/api/configuration.md
index 46d9ede0c9..bc58e190b8 100644
--- a/docs/source/en/api/configuration.md
+++ b/docs/source/en/api/configuration.md
@@ -16,7 +16,7 @@ Schedulers from [`~schedulers.scheduling_utils.SchedulerMixin`] and models from
-To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `huggingface-cli login`.
+To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf auth login`.
diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md
index 574b8499e1..20b5fcb88a 100644
--- a/docs/source/en/api/loaders/lora.md
+++ b/docs/source/en/api/loaders/lora.md
@@ -26,6 +26,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
+- [`SkyReelsV2LoraLoaderMixin`] provides similar functions for [SkyReels-V2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/skyreels_v2).
- [`CogView4LoraLoaderMixin`] provides similar functions for [CogView4](https://huggingface.co/docs/diffusers/main/en/api/pipelines/cogview4).
- [`AmusedLoraLoaderMixin`] is for the [`AmusedPipeline`].
- [`HiDreamImageLoraLoaderMixin`] provides similar functions for [HiDream Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hidream)
@@ -92,6 +93,10 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
+## SkyReelsV2LoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.SkyReelsV2LoraLoaderMixin
+
## AmusedLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
@@ -100,6 +105,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
[[autodoc]] loaders.lora_pipeline.HiDreamImageLoraLoaderMixin
-## WanLoraLoaderMixin
+## LoraBaseMixin
-[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
\ No newline at end of file
+[[autodoc]] loaders.lora_base.LoraBaseMixin
\ No newline at end of file
diff --git a/docs/source/en/api/models/skyreels_v2_transformer_3d.md b/docs/source/en/api/models/skyreels_v2_transformer_3d.md
new file mode 100644
index 0000000000..c1c8c2c7bc
--- /dev/null
+++ b/docs/source/en/api/models/skyreels_v2_transformer_3d.md
@@ -0,0 +1,30 @@
+
+
+# SkyReelsV2Transformer3DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [SkyReels-V2](https://github.com/SkyworkAI/SkyReels-V2) by the Skywork AI.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import SkyReelsV2Transformer3DModel
+
+transformer = SkyReelsV2Transformer3DModel.from_pretrained("Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## SkyReelsV2Transformer3DModel
+
+[[autodoc]] SkyReelsV2Transformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md
index eb78c8b704..ad292abca2 100644
--- a/docs/source/en/api/pipelines/amused.md
+++ b/docs/source/en/api/pipelines/amused.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# aMUSEd
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
diff --git a/docs/source/en/api/pipelines/attend_and_excite.md b/docs/source/en/api/pipelines/attend_and_excite.md
index ca0aa7af98..b5ce3bb767 100644
--- a/docs/source/en/api/pipelines/attend_and_excite.md
+++ b/docs/source/en/api/pipelines/attend_and_excite.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Attend-and-Excite
Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
diff --git a/docs/source/en/api/pipelines/audioldm.md b/docs/source/en/api/pipelines/audioldm.md
index a5ef9c4872..6b143d2990 100644
--- a/docs/source/en/api/pipelines/audioldm.md
+++ b/docs/source/en/api/pipelines/audioldm.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# AudioLDM
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
diff --git a/docs/source/en/api/pipelines/blip_diffusion.md b/docs/source/en/api/pipelines/blip_diffusion.md
index c13288d489..d94281a4a9 100644
--- a/docs/source/en/api/pipelines/blip_diffusion.md
+++ b/docs/source/en/api/pipelines/blip_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# BLIP-Diffusion
BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
diff --git a/docs/source/en/api/pipelines/chroma.md b/docs/source/en/api/pipelines/chroma.md
index 4e2d144421..40e290e4bd 100644
--- a/docs/source/en/api/pipelines/chroma.md
+++ b/docs/source/en/api/pipelines/chroma.md
@@ -36,7 +36,7 @@ import torch
from diffusers import ChromaPipeline
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
-pipe.enabe_model_cpu_offload()
+pipe.enable_model_cpu_offload()
prompt = [
"A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
diff --git a/docs/source/en/api/pipelines/controlnetxs.md b/docs/source/en/api/pipelines/controlnetxs.md
index 2eebcc6b74..aea8cb2e86 100644
--- a/docs/source/en/api/pipelines/controlnetxs.md
+++ b/docs/source/en/api/pipelines/controlnetxs.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS
diff --git a/docs/source/en/api/pipelines/controlnetxs_sdxl.md b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
index 0862a5d798..76937b16c5 100644
--- a/docs/source/en/api/pipelines/controlnetxs_sdxl.md
+++ b/docs/source/en/api/pipelines/controlnetxs_sdxl.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# ControlNet-XS with Stable Diffusion XL
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
diff --git a/docs/source/en/api/pipelines/cosmos.md b/docs/source/en/api/pipelines/cosmos.md
index 99deef37e1..dba807c5ce 100644
--- a/docs/source/en/api/pipelines/cosmos.md
+++ b/docs/source/en/api/pipelines/cosmos.md
@@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
+## Loading original format checkpoints
+
+Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
+
+```python
+import torch
+from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
+
+model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
+transformer = CosmosTransformer3DModel.from_single_file(
+ "https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
+negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
+
+output = pipe(
+ prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
+).images[0]
+output.save("output.png")
+```
+
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
diff --git a/docs/source/en/api/pipelines/dance_diffusion.md b/docs/source/en/api/pipelines/dance_diffusion.md
index 64a738f17c..5805561e49 100644
--- a/docs/source/en/api/pipelines/dance_diffusion.md
+++ b/docs/source/en/api/pipelines/dance_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Dance Diffusion
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
diff --git a/docs/source/en/api/pipelines/diffedit.md b/docs/source/en/api/pipelines/diffedit.md
index 02a76cf589..9734ca2eab 100644
--- a/docs/source/en/api/pipelines/diffedit.md
+++ b/docs/source/en/api/pipelines/diffedit.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# DiffEdit
[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
diff --git a/docs/source/en/api/pipelines/flux.md b/docs/source/en/api/pipelines/flux.md
index ef29e77ce2..ca39d71814 100644
--- a/docs/source/en/api/pipelines/flux.md
+++ b/docs/source/en/api/pipelines/flux.md
@@ -39,6 +39,7 @@ Flux comes in the following variants:
| Canny Control (LoRA) | [`black-forest-labs/FLUX.1-Canny-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora) |
| Depth Control (LoRA) | [`black-forest-labs/FLUX.1-Depth-dev-lora`](https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora) |
| Redux (Adapter) | [`black-forest-labs/FLUX.1-Redux-dev`](https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev) |
+| Kontext | [`black-forest-labs/FLUX.1-kontext`](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) |
All checkpoints have different usage which we detail below.
@@ -273,6 +274,46 @@ images = pipe(
images[0].save("flux-redux.png")
```
+### Kontext
+
+Flux Kontext is a model that allows in-context control of the image generation process, allowing for editing, refinement, relighting, style transfer, character customization, and more.
+
+```python
+import torch
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+pipe = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+)
+pipe.to("cuda")
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png").convert("RGB")
+prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+image = pipe(
+ image=image,
+ prompt=prompt,
+ guidance_scale=2.5,
+ generator=torch.Generator().manual_seed(42),
+).images[0]
+image.save("flux-kontext.png")
+```
+
+Flux Kontext comes with an integrity safety checker, which should be run after the image generation step. To run the safety checker, install the official repository from [black-forest-labs/flux](https://github.com/black-forest-labs/flux) and add the following code:
+
+```python
+from flux.content_filters import PixtralContentFilter
+
+# ... pipeline invocation to generate images
+
+integrity_checker = PixtralContentFilter(torch.device("cuda"))
+image_ = np.array(image) / 255.0
+image_ = 2 * image_ - 1
+image_ = torch.from_numpy(image_).to("cuda", dtype=torch.float32).unsqueeze(0).permute(0, 3, 1, 2)
+if integrity_checker.test_image(image_):
+ raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
+```
+
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
diff --git a/docs/source/en/api/pipelines/i2vgenxl.md b/docs/source/en/api/pipelines/i2vgenxl.md
index eea7eeab19..76a51a6cd5 100644
--- a/docs/source/en/api/pipelines/i2vgenxl.md
+++ b/docs/source/en/api/pipelines/i2vgenxl.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# I2VGen-XL
[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.
diff --git a/docs/source/en/api/pipelines/musicldm.md b/docs/source/en/api/pipelines/musicldm.md
index 5072bcc4fb..c2297162f7 100644
--- a/docs/source/en/api/pipelines/musicldm.md
+++ b/docs/source/en/api/pipelines/musicldm.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MusicLDM
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
diff --git a/docs/source/en/api/pipelines/paint_by_example.md b/docs/source/en/api/pipelines/paint_by_example.md
index 769156643b..362c26de68 100644
--- a/docs/source/en/api/pipelines/paint_by_example.md
+++ b/docs/source/en/api/pipelines/paint_by_example.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Paint by Example
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
diff --git a/docs/source/en/api/pipelines/panorama.md b/docs/source/en/api/pipelines/panorama.md
index a9a95759d6..9f61388dd5 100644
--- a/docs/source/en/api/pipelines/panorama.md
+++ b/docs/source/en/api/pipelines/panorama.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# MultiDiffusion
diff --git a/docs/source/en/api/pipelines/pia.md b/docs/source/en/api/pipelines/pia.md
index a58d7fbe8d..7bd480b49a 100644
--- a/docs/source/en/api/pipelines/pia.md
+++ b/docs/source/en/api/pipelines/pia.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Image-to-Video Generation with PIA (Personalized Image Animator)
diff --git a/docs/source/en/api/pipelines/self_attention_guidance.md b/docs/source/en/api/pipelines/self_attention_guidance.md
index f86cbc0b6f..5578fdfa63 100644
--- a/docs/source/en/api/pipelines/self_attention_guidance.md
+++ b/docs/source/en/api/pipelines/self_attention_guidance.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Self-Attention Guidance
[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
diff --git a/docs/source/en/api/pipelines/semantic_stable_diffusion.md b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
index 99395e75a9..1ce44cf2de 100644
--- a/docs/source/en/api/pipelines/semantic_stable_diffusion.md
+++ b/docs/source/en/api/pipelines/semantic_stable_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Semantic Guidance
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
diff --git a/docs/source/en/api/pipelines/skyreels_v2.md b/docs/source/en/api/pipelines/skyreels_v2.md
new file mode 100644
index 0000000000..cd94f2a75c
--- /dev/null
+++ b/docs/source/en/api/pipelines/skyreels_v2.md
@@ -0,0 +1,367 @@
+
+
+
+
+# SkyReels-V2: Infinite-length Film Generative model
+
+[SkyReels-V2](https://huggingface.co/papers/2504.13074) by the SkyReels Team.
+
+*Recent advances in video generation have been driven by diffusion models and autoregressive frameworks, yet critical challenges persist in harmonizing prompt adherence, visual quality, motion dynamics, and duration: compromises in motion dynamics to enhance temporal visual quality, constrained video duration (5-10 seconds) to prioritize resolution, and inadequate shot-aware generation stemming from general-purpose MLLMs' inability to interpret cinematic grammar, such as shot composition, actor expressions, and camera motions. These intertwined limitations hinder realistic long-form synthesis and professional film-style generation. To address these limitations, we propose SkyReels-V2, an Infinite-length Film Generative Model, that synergizes Multi-modal Large Language Model (MLLM), Multi-stage Pretraining, Reinforcement Learning, and Diffusion Forcing Framework. Firstly, we design a comprehensive structural representation of video that combines the general descriptions by the Multi-modal LLM and the detailed shot language by sub-expert models. Aided with human annotation, we then train a unified Video Captioner, named SkyCaptioner-V1, to efficiently label the video data. Secondly, we establish progressive-resolution pretraining for the fundamental video generation, followed by a four-stage post-training enhancement: Initial concept-balanced Supervised Fine-Tuning (SFT) improves baseline quality; Motion-specific Reinforcement Learning (RL) training with human-annotated and synthetic distortion data addresses dynamic artifacts; Our diffusion forcing framework with non-decreasing noise schedules enables long-video synthesis in an efficient search space; Final high-quality SFT refines visual fidelity. All the code and models are available at [this https URL](https://github.com/SkyworkAI/SkyReels-V2).*
+
+You can find all the original SkyReels-V2 checkpoints under the [Skywork](https://huggingface.co/collections/Skywork/skyreels-v2-6801b1b93df627d441d0d0d9) organization.
+
+The following SkyReels-V2 models are supported in Diffusers:
+- [SkyReels-V2 DF 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers)
+- [SkyReels-V2 DF 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P-Diffusers)
+- [SkyReels-V2 DF 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-720P-Diffusers)
+- [SkyReels-V2 T2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-540P-Diffusers)
+- [SkyReels-V2 T2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-T2V-14B-720P-Diffusers)
+- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
+- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
+- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
+- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
+
+> [!TIP]
+> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
+
+### A _Visual_ Demonstration
+
+ An example with these parameters:
+ base_num_frames=97, num_frames=97, num_inference_steps=30, ar_step=5, causal_block_size=5
+
+ vae_scale_factor_temporal -> 4
+ num_latent_frames: (97-1)//vae_scale_factor_temporal+1 = 25 frames -> 5 blocks of 5 frames each
+
+ base_num_latent_frames = (97-1)//vae_scale_factor_temporal+1 = 25 → blocks = 25//5 = 5 blocks
+ This 5 blocks means the maximum context length of the model is 25 frames in the latent space.
+
+ Asynchronous Processing Timeline:
+ ┌─────────────────────────────────────────────────────────────────┐
+ │ Steps: 1 6 11 16 21 26 31 36 41 46 50 │
+ │ Block 1: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ │ Block 2: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ │ Block 3: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ │ Block 4: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ │ Block 5: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ └─────────────────────────────────────────────────────────────────┘
+
+ For Long Videos (num_frames > base_num_frames):
+ base_num_frames acts as the "sliding window size" for processing long videos.
+
+ Example: 257-frame video with base_num_frames=97, overlap_history=17
+ ┌──── Iteration 1 (frames 1-97) ────┐
+ │ Processing window: 97 frames │ → 5 blocks, async processing
+ │ Generates: frames 1-97 │
+ └───────────────────────────────────┘
+ ┌────── Iteration 2 (frames 81-177) ──────┐
+ │ Processing window: 97 frames │
+ │ Overlap: 17 frames (81-97) from prev │ → 5 blocks, async processing
+ │ Generates: frames 98-177 │
+ └─────────────────────────────────────────┘
+ ┌────── Iteration 3 (frames 161-257) ──────┐
+ │ Processing window: 97 frames │
+ │ Overlap: 17 frames (161-177) from prev │ → 5 blocks, async processing
+ │ Generates: frames 178-257 │
+ └──────────────────────────────────────────┘
+
+ Each iteration independently runs the asynchronous processing with its own 5 blocks.
+ base_num_frames controls:
+ 1. Memory usage (larger window = more VRAM)
+ 2. Model context length (must match training constraints)
+ 3. Number of blocks per iteration (base_num_latent_frames // causal_block_size)
+
+ Each block takes 30 steps to complete denoising.
+ Block N starts at step: 1 + (N-1) x ar_step
+ Total steps: 30 + (5-1) x 5 = 50 steps
+
+
+ Synchronous mode (ar_step=0) would process all blocks/frames simultaneously:
+ ┌──────────────────────────────────────────────┐
+ │ Steps: 1 ... 30 │
+ │ All blocks: [■■■■■■■■■■■■■■■■■■■■■■■■■■■■■■] │
+ └──────────────────────────────────────────────┘
+ Total steps: 30 steps
+
+
+ An example on how the step matrix is constructed for asynchronous processing:
+ Given the parameters: (num_inference_steps=30, flow_shift=8, num_frames=97, ar_step=5, causal_block_size=5)
+ - num_latent_frames = (97 frames - 1) // (4 temporal downsampling) + 1 = 25
+ - step_template = [999, 995, 991, 986, 980, 975, 969, 963, 956, 948,
+ 941, 932, 922, 912, 901, 888, 874, 859, 841, 822,
+ 799, 773, 743, 708, 666, 615, 551, 470, 363, 216]
+
+ The algorithm creates a 50x25 step_matrix where:
+ - Row 1: [999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
+ - Row 2: [995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
+ - Row 3: [991, 991, 991, 991, 991, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
+ - ...
+ - Row 7: [969, 969, 969, 969, 969, 995, 995, 995, 995, 995, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999, 999]
+ - ...
+ - Row 21: [799, 799, 799, 799, 799, 888, 888, 888, 888, 888, 941, 941, 941, 941, 941, 975, 975, 975, 975, 975, 999, 999, 999, 999, 999]
+ - ...
+ - Row 35: [ 0, 0, 0, 0, 0, 216, 216, 216, 216, 216, 666, 666, 666, 666, 666, 822, 822, 822, 822, 822, 901, 901, 901, 901, 901]
+ - ...
+ - Row 42: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 551, 551, 551, 551, 551, 773, 773, 773, 773, 773]
+ - ...
+ - Row 50: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 216, 216, 216, 216, 216]
+
+ Detailed Row 6 Analysis:
+ - step_matrix[5]: [ 975, 975, 975, 975, 975, 999, 999, 999, 999, 999, 999, ..., 999]
+ - step_index[5]: [ 6, 6, 6, 6, 6, 1, 1, 1, 1, 1, 0, ..., 0]
+ - step_update_mask[5]: [True,True,True,True,True,True,True,True,True,True,False, ...,False]
+ - valid_interval[5]: (0, 25)
+
+ Key Pattern: Block i lags behind Block i-1 by exactly ar_step=5 timesteps, creating the
+ staggered "diffusion forcing" effect where later blocks condition on cleaner earlier blocks.
+
+### Text-to-Video Generation
+
+The example below demonstrates how to generate a video from text.
+
+
+
+
+Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
+
+From the original repo:
+>You can use --ar_step 5 to enable asynchronous inference. When asynchronous inference, --causal_block_size 5 is recommended while it is not supposed to be set for synchronous generation... Asynchronous inference will take more steps to diffuse the whole sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous inference may improve the instruction following and visual consistent performance.
+
+```py
+# pip install ftfy
+import torch
+from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video
+
+vae = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32)
+transformer = AutoModel.from_pretrained("Skywork/SkyReels-V2-DF-14B-540P-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16)
+
+pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+ "Skywork/SkyReels-V2-DF-14B-540P-Diffusers",
+ vae=vae,
+ transformer=transformer,
+ torch_dtype=torch.bfloat16
+)
+flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+pipeline = pipeline.to("cuda")
+
+prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+output = pipeline(
+ prompt=prompt,
+ num_inference_steps=30,
+ height=544, # 720 for 720P
+ width=960, # 1280 for 720P
+ num_frames=97,
+ base_num_frames=97, # 121 for 720P
+ ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
+ causal_block_size=5, # Number of frames in each block for asynchronous processing
+ overlap_history=None, # Number of frames to overlap for smooth transitions in long videos; 17 for long video generations
+ addnoise_condition=20, # Improves consistency in long video generation
+).frames[0]
+export_to_video(output, "T2V.mp4", fps=24, quality=8)
+```
+
+
+
+
+### First-Last-Frame-to-Video Generation
+
+The example below demonstrates how to use the image-to-video pipeline to generate a video using a text description, a starting frame, and an ending frame.
+
+
+
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingImageToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image
+
+
+model_id = "Skywork/SkyReels-V2-DF-14B-720P-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipeline = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
+ model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+pipeline.to("cuda")
+
+first_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_first_frame.png")
+last_frame = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flf2v_input_last_frame.png")
+
+def aspect_ratio_resize(image, pipeline, max_area=720 * 1280):
+ aspect_ratio = image.height / image.width
+ mod_value = pipeline.vae_scale_factor_spatial * pipeline.transformer.config.patch_size[1]
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+ return image, height, width
+
+def center_crop_resize(image, height, width):
+ # Calculate resize ratio to match first frame dimensions
+ resize_ratio = max(width / image.width, height / image.height)
+
+ # Resize the image
+ width = round(image.width * resize_ratio)
+ height = round(image.height * resize_ratio)
+ size = [width, height]
+ image = TF.center_crop(image, size)
+
+ return image, height, width
+
+first_frame, height, width = aspect_ratio_resize(first_frame, pipeline)
+if last_frame.size != first_frame.size:
+ last_frame, _, _ = center_crop_resize(last_frame, height, width)
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipeline(
+ image=first_frame, last_image=last_frame, prompt=prompt, height=height, width=width, guidance_scale=5.0
+).frames[0]
+export_to_video(output, "output.mp4", fps=24, quality=8)
+```
+
+
+
+
+
+### Video-to-Video Generation
+
+
+
+
+`SkyReelsV2DiffusionForcingVideoToVideoPipeline` extends a given video.
+
+```python
+import numpy as np
+import torch
+import torchvision.transforms.functional as TF
+from diffusers import AutoencoderKLWan, SkyReelsV2DiffusionForcingVideoToVideoPipeline, UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_video
+
+
+model_id = "Skywork/SkyReels-V2-DF-14B-540P-Diffusers"
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+pipeline = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
+ model_id, vae=vae, torch_dtype=torch.bfloat16
+)
+flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config, flow_shift=flow_shift)
+pipeline.to("cuda")
+
+video = load_video("input_video.mp4")
+
+prompt = "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird's feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
+
+output = pipeline(
+ video=video, prompt=prompt, height=544, width=960, guidance_scale=5.0,
+ num_inference_steps=30, num_frames=257, base_num_frames=97#, ar_step=5, causal_block_size=5,
+).frames[0]
+export_to_video(output, "output.mp4", fps=24, quality=8)
+# Total frames will be the number of frames of given video + 257
+```
+
+
+
+
+
+## Notes
+
+- SkyReels-V2 supports LoRAs with [`~loaders.SkyReelsV2LoraLoaderMixin.load_lora_weights`].
+
+
+ Show example code
+
+ ```py
+ # pip install ftfy
+ import torch
+ from diffusers import AutoModel, SkyReelsV2DiffusionForcingPipeline
+ from diffusers.utils import export_to_video
+
+ vae = AutoModel.from_pretrained(
+ "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", subfolder="vae", torch_dtype=torch.float32
+ )
+ pipeline = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+ "Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers", vae=vae, torch_dtype=torch.bfloat16
+ )
+ pipeline.to("cuda")
+
+ pipeline.load_lora_weights("benjamin-paine/steamboat-willie-1.3b", adapter_name="steamboat-willie")
+ pipeline.set_adapters("steamboat-willie")
+
+ pipeline.enable_model_cpu_offload()
+
+ # use "steamboat willie style" to trigger the LoRA
+ prompt = """
+ steamboat willie style, golden era animation, The camera rushes from far to near in a low-angle shot,
+ revealing a white ferret on a log. It plays, leaps into the water, and emerges, as the camera zooms in
+ for a close-up. Water splashes berry bushes nearby, while moss, snow, and leaves blanket the ground.
+ Birch trees and a light blue sky frame the scene, with ferns in the foreground. Side lighting casts dynamic
+ shadows and warm highlights. Medium composition, front view, low angle, with depth of field.
+ """
+
+ output = pipeline(
+ prompt=prompt,
+ num_frames=97,
+ guidance_scale=6.0,
+ ).frames[0]
+ export_to_video(output, "output.mp4", fps=24)
+ ```
+
+
+
+
+## SkyReelsV2DiffusionForcingPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingPipeline
+ - all
+ - __call__
+
+## SkyReelsV2DiffusionForcingImageToVideoPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingImageToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2DiffusionForcingVideoToVideoPipeline
+
+[[autodoc]] SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2Pipeline
+
+[[autodoc]] SkyReelsV2Pipeline
+ - all
+ - __call__
+
+## SkyReelsV2ImageToVideoPipeline
+
+[[autodoc]] SkyReelsV2ImageToVideoPipeline
+ - all
+ - __call__
+
+## SkyReelsV2PipelineOutput
+
+[[autodoc]] pipelines.skyreels_v2.pipeline_output.SkyReelsV2PipelineOutput
\ No newline at end of file
diff --git a/docs/source/en/api/pipelines/stable_diffusion/gligen.md b/docs/source/en/api/pipelines/stable_diffusion/gligen.md
index 73be0b4ca8..e9704fc1de 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/gligen.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/gligen.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# GLIGEN (Grounded Language-to-Image Generation)
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
index 4d7fda2a0c..75f052b08f 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/k_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# K-Diffusion
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
diff --git a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
index 9f54538968..4c52ed90f0 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/ldm3d_diffusion.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text-to-(RGB, depth)
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
index 9eb58a49d7..211b26889a 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md
@@ -31,7 +31,7 @@ _As the model is gated, before using it with diffusers you first need to go to t
Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
diff --git a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
index ac5b97b672..1736491107 100644
--- a/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
+++ b/docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_safe.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Safe Stable Diffusion
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
diff --git a/docs/source/en/api/pipelines/text_to_video.md b/docs/source/en/api/pipelines/text_to_video.md
index 116aea736f..7faf88d133 100644
--- a/docs/source/en/api/pipelines/text_to_video.md
+++ b/docs/source/en/api/pipelines/text_to_video.md
@@ -10,11 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
-
-
-🧪 This pipeline is for research purposes only.
-
-
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# Text-to-video
diff --git a/docs/source/en/api/pipelines/text_to_video_zero.md b/docs/source/en/api/pipelines/text_to_video_zero.md
index 7966f43390..5fe3789d82 100644
--- a/docs/source/en/api/pipelines/text_to_video_zero.md
+++ b/docs/source/en/api/pipelines/text_to_video_zero.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# Text2Video-Zero
diff --git a/docs/source/en/api/pipelines/unclip.md b/docs/source/en/api/pipelines/unclip.md
index c9a3164226..8011a4b533 100644
--- a/docs/source/en/api/pipelines/unclip.md
+++ b/docs/source/en/api/pipelines/unclip.md
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# unCLIP
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
diff --git a/docs/source/en/api/pipelines/unidiffuser.md b/docs/source/en/api/pipelines/unidiffuser.md
index bce55b67ed..7d767f2db5 100644
--- a/docs/source/en/api/pipelines/unidiffuser.md
+++ b/docs/source/en/api/pipelines/unidiffuser.md
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
# UniDiffuser
diff --git a/docs/source/en/api/pipelines/wan.md b/docs/source/en/api/pipelines/wan.md
index 18b8207e3b..81cd242151 100644
--- a/docs/source/en/api/pipelines/wan.md
+++ b/docs/source/en/api/pipelines/wan.md
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
```py
# pip install ftfy
import torch
- from diffusers import WanPipeline, AutoModel
+ from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
- vae = AutoModel.from_single_file(
+ vae = AutoencoderKLWan.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
)
- transformer = AutoModel.from_single_file(
+ transformer = WanTransformer3DModel.from_single_file(
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
torch_dtype=torch.bfloat16
)
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
index 561df2017d..2be3631d84 100644
--- a/docs/source/en/api/pipelines/wuerstchen.md
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
# Würstchen
+> [!WARNING]
+> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
+
diff --git a/docs/source/en/api/quantization.md b/docs/source/en/api/quantization.md
index 713748ae5c..31271f1722 100644
--- a/docs/source/en/api/quantization.md
+++ b/docs/source/en/api/quantization.md
@@ -27,19 +27,19 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui
## BitsAndBytesConfig
-[[autodoc]] BitsAndBytesConfig
+[[autodoc]] quantizers.quantization_config.BitsAndBytesConfig
## GGUFQuantizationConfig
-[[autodoc]] GGUFQuantizationConfig
+[[autodoc]] quantizers.quantization_config.GGUFQuantizationConfig
## QuantoConfig
-[[autodoc]] QuantoConfig
+[[autodoc]] quantizers.quantization_config.QuantoConfig
## TorchAoConfig
-[[autodoc]] TorchAoConfig
+[[autodoc]] quantizers.quantization_config.TorchAoConfig
## DiffusersQuantizer
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 04e907a542..0aca1d22c1 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -12,37 +12,24 @@ specific language governing permissions and limitations under the License.
-
+
# Diffusers
-🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or want to train your own diffusion model, 🤗 Diffusers is a modular toolbox that supports both. Our library is designed with a focus on [usability over performance](conceptual/philosophy#usability-over-performance), [simple over easy](conceptual/philosophy#simple-over-easy), and [customizability over abstractions](conceptual/philosophy#tweakable-contributorfriendly-over-abstraction).
+Diffusers is a library of state-of-the-art pretrained diffusion models for generating videos, images, and audio.
-The library has three main components:
+The library revolves around the [`DiffusionPipeline`], an API designed for:
-- State-of-the-art diffusion pipelines for inference with just a few lines of code. There are many pipelines in 🤗 Diffusers, check out the table in the pipeline [overview](api/pipelines/overview) for a complete list of available pipelines and the task they solve.
-- Interchangeable [noise schedulers](api/schedulers/overview) for balancing trade-offs between generation speed and quality.
-- Pretrained [models](api/models) that can be used as building blocks, and combined with schedulers, for creating your own end-to-end diffusion systems.
+- easy inference with only a few lines of code
+- flexibility to mix-and-match pipeline components (models, schedulers)
+- loading and using adapters like LoRA
-
+Diffusers also comes with optimizations - such as offloading and quantization - to ensure even the largest models are accessible on memory-constrained devices. If memory is not an issue, Diffusers supports torch.compile to boost inference speed.
+
+Get started right away with a Diffusers model on the [Hub](https://huggingface.co/models?library=diffusers&sort=trending) today!
+
+## Learn
+
+If you're a beginner, we recommend starting with the [Hugging Face Diffusion Models Course](https://huggingface.co/learn/diffusion-course/unit0/1). You'll learn the theory behind diffusion models, and learn how to use the Diffusers library to generate images, fine-tune your own models, and more.
diff --git a/docs/source/en/modular_diffusers/auto_pipeline_blocks.md b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md
new file mode 100644
index 0000000000..50c3250512
--- /dev/null
+++ b/docs/source/en/modular_diffusers/auto_pipeline_blocks.md
@@ -0,0 +1,316 @@
+
+
+# AutoPipelineBlocks
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+`AutoPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that automatically selects which sub-blocks to run based on the inputs provided at runtime, creating conditional workflows that adapt to different scenarios. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use.
+
+In this tutorial, we will show you how to create an `AutoPipelineBlocks` and learn more about how the conditional selection works.
+
+
+
+Other types of multi-blocks include [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
+
+Additionally, like all `ModularPipelineBlocks`, `AutoPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
+
+
+
+For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided.
+
+Let's see an example. We'll use the helper function from the [PipelineBlock guide](./pipeline_block.md) to create our blocks:
+
+**Helper Function**
+
+```py
+from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
+import torch
+
+def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
+ class TestBlock(PipelineBlock):
+ model_name = "test"
+
+ @property
+ def inputs(self):
+ return inputs
+
+ @property
+ def intermediate_inputs(self):
+ return intermediate_inputs
+
+ @property
+ def intermediate_outputs(self):
+ return intermediate_outputs
+
+ @property
+ def description(self):
+ return description if description is not None else ""
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ if block_fn is not None:
+ block_state = block_fn(block_state, state)
+ self.set_block_state(state, block_state)
+ return components, state
+
+ return TestBlock
+```
+
+Now let's create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines.
+
+
+```py
+from diffusers.modular_pipelines import AutoPipelineBlocks
+
+# These are dummy blocks and we only focus on "inputs" for our purpose
+inputs = [InputParam(name="prompt")]
+# block_fn prints out which workflow is running so we can see the execution order at runtime
+block_fn = lambda x, y: print("running the text-to-image workflow")
+block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!")
+
+inputs = [InputParam(name="prompt"), InputParam(name="image")]
+block_fn = lambda x, y: print("running the image-to-image workflow")
+block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!")
+
+inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
+block_fn = lambda x, y: print("running the inpaint workflow")
+block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!")
+
+class AutoImageBlocks(AutoPipelineBlocks):
+ # List of sub-block classes to choose from
+ block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
+ # Names for each block in the same order
+ block_names = ["inpaint", "img2img", "text2img"]
+ # Trigger inputs that determine which block to run
+ # - "mask" triggers inpaint workflow
+ # - "image" triggers img2img workflow (but only if mask is not provided)
+ # - if none of above, runs the text2img workflow (default)
+ block_trigger_inputs = ["mask", "image", None]
+ # Description is extremely important for AutoPipelineBlocks
+ @property
+ def description(self):
+ return (
+ "Pipeline generates images given different types of conditions!\n"
+ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ + " - inpaint workflow is run when `mask` is provided.\n"
+ + " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ + " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
+ )
+
+# Create the blocks
+auto_blocks = AutoImageBlocks()
+# convert to pipeline
+auto_pipeline = auto_blocks.init_pipeline()
+```
+
+Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime.
+
+```py
+AutoImageBlocks(
+ Class: AutoPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: ['mask', 'image']
+ ====================================================================================================
+
+
+ Description: Pipeline generates images given different types of conditions!
+ This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
+ - inpaint workflow is run when `mask` is provided.
+ - img2img workflow is run when `image` is provided (but only when `mask` is not provided).
+ - text2img workflow is run when neither `image` nor `mask` is provided.
+
+
+
+ Sub-Blocks:
+ • inpaint [trigger: mask] (TestBlock)
+ Description: I'm a inpaint workflow!
+
+ • img2img [trigger: image] (TestBlock)
+ Description: I'm a image-to-image workflow!
+
+ • text2img [default] (TestBlock)
+ Description: I'm a text-to-image workflow!
+
+)
+```
+
+Check out the documentation with `print(auto_pipeline.doc)`:
+
+```py
+>>> print(auto_pipeline.doc)
+class AutoImageBlocks
+
+ Pipeline generates images given different types of conditions!
+ This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
+ - inpaint workflow is run when `mask` is provided.
+ - img2img workflow is run when `image` is provided (but only when `mask` is not provided).
+ - text2img workflow is run when neither `image` nor `mask` is provided.
+
+ Inputs:
+
+ prompt (`None`, *optional*):
+
+ image (`None`, *optional*):
+
+ mask (`None`, *optional*):
+```
+
+There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right?
+
+This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks.
+
+Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running:
+
+```py
+>>> _ = auto_pipeline(image="image", mask="mask")
+running the inpaint workflow
+>>> _ = auto_pipeline(image="image")
+running the image-to-image workflow
+>>> _ = auto_pipeline(prompt="prompt")
+running the text-to-image workflow
+>>> _ = auto_pipeline(image="prompt", mask="mask")
+running the inpaint workflow
+```
+
+However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines.
+
+Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime.
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
+inputs = [InputParam(name="ip_adapter_image")]
+block_fn = lambda x, y: print("running the ip-adapter workflow")
+block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!")
+
+class AutoIPAdapter(AutoPipelineBlocks):
+ block_classes = [block_ipa_cls]
+ block_names = ["ip-adapter"]
+ block_trigger_inputs = ["ip_adapter_image"]
+ @property
+ def description(self):
+ return "Run IP Adapter step if `ip_adapter_image` is provided."
+```
+
+Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`:
+
+```py
+auto_ipa_blocks = AutoIPAdapter()
+blocks_dict = InsertableDict()
+blocks_dict["ip-adapter"] = auto_ipa_blocks
+blocks_dict["image-generation"] = auto_blocks
+all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
+pipeline = all_blocks.init_pipeline()
+```
+
+Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases.
+
+```py
+>>> all_blocks
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: ['image', 'mask', 'ip_adapter_image']
+ Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`).
+ ====================================================================================================
+
+
+ Description:
+
+
+ Sub-Blocks:
+ [0] ip-adapter (AutoIPAdapter)
+ Description: Run IP Adapter step if `ip_adapter_image` is provided.
+
+
+ [1] image-generation (AutoImageBlocks)
+ Description: Pipeline generates images given different types of conditions!
+ This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
+ - inpaint workflow is run when `mask` is provided.
+ - img2img workflow is run when `image` is provided (but only when `mask` is not provided).
+ - text2img workflow is run when neither `image` nor `mask` is provided.
+
+
+)
+
+```
+
+This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs.
+
+Let's try some examples:
+
+`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block.
+
+```py
+>>> all_blocks.get_execution_blocks('mask')
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ Description:
+
+
+ Sub-Blocks:
+ [0] image-generation (TestBlock)
+ Description: I'm a inpaint workflow!
+
+)
+```
+
+Let's also actually run the pipeline to confirm:
+
+```py
+>>> _ = pipeline(mask="mask")
+skipping auto block: AutoIPAdapter
+running the inpaint workflow
+```
+
+Try a few more:
+
+```py
+print(f"inputs: ip_adapter_image:")
+blocks_select = all_blocks.get_execution_blocks('ip_adapter_image')
+print(f"expected_execution_blocks: {blocks_select}")
+print(f"actual execution blocks:")
+_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt")
+# expect to see ip-adapter + text2img
+
+print(f"inputs: image:")
+blocks_select = all_blocks.get_execution_blocks('image')
+print(f"expected_execution_blocks: {blocks_select}")
+print(f"actual execution blocks:")
+_ = pipeline(image="image", prompt="prompt")
+# expect to see img2img
+
+print(f"inputs: prompt:")
+blocks_select = all_blocks.get_execution_blocks('prompt')
+print(f"expected_execution_blocks: {blocks_select}")
+print(f"actual execution blocks:")
+_ = pipeline(prompt="prompt")
+# expect to see text2img (prompt is not a trigger input so fallback to default)
+
+print(f"inputs: mask + ip_adapter_image:")
+blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image')
+print(f"expected_execution_blocks: {blocks_select}")
+print(f"actual execution blocks:")
+_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image")
+# expect to see ip-adapter + inpaint
+```
+
+In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions.
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/components_manager.md b/docs/source/en/modular_diffusers/components_manager.md
new file mode 100644
index 0000000000..15b6c66b9b
--- /dev/null
+++ b/docs/source/en/modular_diffusers/components_manager.md
@@ -0,0 +1,514 @@
+
+
+# Components Manager
+
+
+
+🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading (i.e. `enable_model_cpu_offload` and `enable_sequential_cpu_offload`) which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows.
+
+
+## Basic Operations
+
+Let's start with the most basic operations. First, create a Components Manager:
+
+```py
+from diffusers import ComponentsManager
+comp = ComponentsManager()
+```
+
+Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function):
+
+```py
+from diffusers import AutoModel
+text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+# Returns component_id like 'text_encoder_139917733042864'
+component_id = comp.add("text_encoder", text_encoder)
+```
+
+You can view all registered components and their metadata:
+
+```py
+>>> comp
+Components:
+===============================================================================================================================================
+Models:
+-----------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+-----------------------------------------------------------------------------------------------------------------------------------------------
+text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
+-----------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+And remove components using their unique ID:
+
+```py
+comp.remove("text_encoder_139917733042864")
+```
+
+## Duplicate Detection
+
+The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice.
+
+When you try to add the same object twice, the manager will warn you and return the existing ID:
+
+```py
+>>> comp.add("text_encoder", text_encoder)
+'text_encoder_139917733042864'
+>>> comp.add("text_encoder", text_encoder)
+ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864'
+'text_encoder_139917733042864'
+```
+
+Even if you add the same object under a different name, it will still be detected as a duplicate:
+
+```py
+>>> comp.add("clip", text_encoder)
+ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864'
+To remove a duplicate, call `components_manager.remove('
')`.
+'clip_139917733042864'
+```
+
+However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example:
+
+```py
+>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
+>>> comp.add("text_encoder", text_encoder_2)
+'text_encoder_139917732983664'
+```
+
+This creates a problem - you now have two copies of the same model consuming double the memory:
+
+```py
+>>> comp
+Components:
+===============================================================================================================================================
+Models:
+-----------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+-----------------------------------------------------------------------------------------------------------------------------------------------
+text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
+clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
+text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
+-----------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint:
+
+```py
+from diffusers import ComponentSpec, ComponentsManager
+from transformers import CLIPTextModel
+comp = ComponentsManager()
+
+# Create ComponentSpec for the first text encoder
+spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
+# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder)
+spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
+
+# Load and add both components - the manager will detect they're the same model
+comp.add("text_encoder", spec.load())
+comp.add("text_encoder_duplicated", spec_duplicated.load())
+```
+
+Now the manager detects the duplicate and warns you:
+
+```out
+ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('')`.
+'text_encoder_duplicated_139917580682672'
+```
+
+Both models now show the same `load_id`, making it clear they're the same model:
+
+```py
+>>> comp
+Components:
+======================================================================================================================================================================================================
+Models:
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
+text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+## Collections
+
+Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed.
+
+Here's how collections work in practice:
+
+```py
+comp = ComponentsManager()
+# Create ComponentSpec for the first UNet (SDXL base)
+spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
+# Create ComponentSpec for a different UNet (Juggernaut-XL)
+spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
+
+# Add both UNets to the same collection - the second one will replace the first
+comp.add("unet", spec.load(), collection="sdxl")
+comp.add("unet", spec2.load(), collection="sdxl")
+```
+
+The manager automatically removes the old UNet and adds the new one:
+
+```out
+ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888
+'unet_139917723893136'
+```
+
+Only one UNet remains in the collection:
+
+```py
+>>> comp
+Components:
+====================================================================================================================================================================
+Models:
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------
+unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed.
+
+## Retrieving Components
+
+The Components Manager provides several methods to retrieve registered components.
+
+The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use:
+- exact matches like `comp.get_one(name="unet")`
+- wildcards like `comp.get_one(name="unet*")` for components starting with "unet"
+- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet"
+- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae".
+
+Optionally, You can add collection and load_id as filters e.g. `comp.get_one(name="unet", collection="sdxl")`. If multiple components match, `get_one()` throws an error.
+
+Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`.
+
+```py
+# Get components by name list
+component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
+# Returns: {"text_encoder": component1, "unet": component2, "vae": component3}
+```
+
+## Using Components Manager with Modular Pipelines
+
+The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter:
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+comp = ComponentsManager()
+pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
+```
+
+By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty:
+
+```py
+>>> comp
+Components:
+==================================================
+No components registered.
+==================================================
+```
+
+When you load components on the pipeline, they are automatically registered in the Components Manager:
+
+```py
+>>> pipe.load_components(names="unet")
+>>> comp
+Components:
+==============================================================================================================================================================
+Models:
+--------------------------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+--------------------------------------------------------------------------------------------------------------------------------------------------------------
+unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
+--------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection:
+
+```py
+# Load all default components
+>>> pipe.load_default_components()
+
+# Create a second pipeline using the same Components Manager but with a different collection
+>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
+```
+
+As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager:
+
+```py
+# Get the list of components that pipe2 needs to load
+>>> pipe2.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
+
+# Retrieve all required components from the Components Manager
+>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
+
+# Update the pipeline with the retrieved components
+>>> pipe2.update_components(**comp_dict)
+```
+
+The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates:
+
+```out
+ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400'
+ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424'
+ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504'
+ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904'
+ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304'
+ComponentsManager: component 'unet' already exists as 'unet_139917580609632'
+ComponentsManager: component 'vae' already exists as 'vae_139917722459040'
+ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408'
+ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432'
+```
+
+
+The pipeline is now fully loaded:
+
+```py
+# null_component_names return empty list, meaning everything are loaded
+>>> pipe2.null_component_names
+[]
+```
+
+No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines:
+```py
+>>> comp
+Components:
+========================================================================================================================================================================================
+Models:
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1
+ | | | | | | test2
+text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1
+ | | | | | | test2
+unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
+ | | | | | | test2
+controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1
+ | | | | | | test2
+vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1
+ | | | | | | test2
+image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1
+ | | | | | | test2
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Other Components:
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+ID | Class | Collection
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+tokenizer_139917580599504 | CLIPTokenizer | test1
+ | | test2
+scheduler_139916266559408 | EulerDiscreteScheduler | test1
+ | | test2
+tokenizer_2_139915763443904 | CLIPTokenizer | test1
+ | | test2
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+Additional Component Info:
+==================================================
+```
+
+
+## Automatic Memory Management
+
+The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them:
+
+```py
+comp.enable_auto_cpu_offload(device="cuda")
+```
+
+When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow.
+
+
+
+## Practical Example: Building Modular Workflows with Component Reuse
+
+Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components.
+
+In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline.
+
+Let's create a modular text-to-image workflow by separating it into three workflows: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images.
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
+
+# Create modular blocks and separate text encoding and decoding steps
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
+text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
+decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
+```
+
+Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection
+
+Since we now have 3 different workflows that share components, we create a separate pipeline that serves as a dedicated loader to load all the components, register them to the component manager, and then reuse them across different workflows.
+
+```py
+from diffusers import ComponentsManager, ModularPipeline
+
+# Set up Components Manager with auto offloading
+components = ComponentsManager()
+components.enable_auto_cpu_offload(device="cuda")
+
+# Create a new pipeline to load the components
+t2i_repo = "YiYiXu/modular-demo-auto"
+t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
+
+# convert the 3 blocks into pipelines and attach the same components manager to all 3
+text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
+decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
+t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
+```
+
+Load all components into the loader pipeline, they should all be automatically registered to Components Manager under the "t2i" collection:
+
+```py
+# Load all components (including IP-Adapter and ControlNet for later use)
+t2i_loader_pipe.load_default_components(torch_dtype=torch.float16)
+```
+
+Now distribute the loaded components to each pipeline:
+
+```py
+# Get VAE for decoder (using get_one since there's only one)
+vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
+decoder_node.update_components(vae=vae)
+
+# Get text components for text node (using get_components_by_names for multiple components)
+text_components = components.get_components_by_names(text_node.null_component_names)
+text_node.update_components(**text_components)
+
+# Get remaining components for t2i pipeline
+t2i_components = components.get_components_by_names(t2i_pipe.null_component_names)
+t2i_pipe.update_components(**t2i_components)
+```
+
+Now we can generate images using our modular workflow:
+
+```py
+# Generate text embeddings
+prompt = "an astronaut"
+text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
+
+# Generate latents and decode to image
+generator = torch.Generator(device="cuda").manual_seed(0)
+latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
+image = decoder_node(latents=latents_t2i, output="images")[0]
+image.save("modular_part2_t2i.png")
+```
+
+Let's add a LoRA:
+
+```py
+# Load LoRA weights
+>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
+>>> components
+Components:
+============================================================================================================================================================
+...
+Additional Component Info:
+==================================================
+
+unet:
+ Adapters: ['toy_face']
+```
+
+You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings.
+
+```py
+# Generate with LoRA (reusing existing text embeddings)
+generator = torch.Generator(device="cuda").manual_seed(0)
+latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
+image = decoder_node(latents=latents_lora, output="images")[0]
+image.save("modular_part2_lora.png")
+```
+
+
+Now let's create a refiner pipeline that reuses components from our text-to-image workflow:
+
+```py
+# Create refiner blocks (removing image_encoder and decode since we work with latents)
+refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
+refiner_blocks.sub_blocks.pop("image_encoder")
+refiner_blocks.sub_blocks.pop("decode")
+
+# Create refiner pipeline with different repo and collection,
+# Attach the same component manager to it
+refiner_repo = "YiYiXu/modular_refiner"
+refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
+```
+
+We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection.
+
+```py
+# Load only the refiner UNet (different from t2i UNet)
+refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
+
+# Reuse components from t2i pipeline using pattern matching
+reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
+refiner_pipe.update_components(**reuse_components)
+```
+
+When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows.
+
+Now we can refine any of our generated latents:
+
+```py
+# Refine all our different latents
+refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
+refined_image = decoder_node(latents=refined_latents, output="images")[0]
+refined_image.save("modular_part2_t2i_refine_out.png")
+
+refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
+refined_image = decoder_node(latents=refined_latents, output="images")[0]
+refined_image.save("modular_part2_lora_refine_out.png")
+```
+
+
+Here are the results from our modular pipeline examples.
+
+#### Base Text-to-Image Generation
+| Base Text-to-Image | Base Text-to-Image (Refined) |
+|-------------------|------------------------------|
+|  |  |
+
+#### LoRA
+| LoRA | LoRA (Refined) |
+|-------------------|------------------------------|
+|  |  |
+
diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md
new file mode 100644
index 0000000000..cb7b87552a
--- /dev/null
+++ b/docs/source/en/modular_diffusers/end_to_end_guide.md
@@ -0,0 +1,648 @@
+
+
+# End-to-End Developer Guide: Building with Modular Diffusers
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+
+In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.
+
+We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
+
+1. **Start with an existing pipeline as a base**
+ - Identify which existing pipeline is most similar to the one you want to implement
+ - Determine what part of the pipeline needs modification
+
+2. **Build a working pipeline structure first**
+ - Assemble the complete pipeline structure
+ - Use existing blocks wherever possible
+ - For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet
+
+3. **Set up an example**
+ - Create a simple inference script with expected inputs/outputs
+
+4. **Implement your custom logic and test incrementally**
+ - Add the custom logics the blocks you want to change
+ - Test incrementally, and inspect pipeline states and debug as needed
+
+Let's see how this works with the Differential Diffusion example.
+
+
+## Differential Diffusion Pipeline
+
+### Start with an existing pipeline
+
+Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
+
+```py
+>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+>>> IMAGE2IMAGE_BLOCKS = InsertableDict([
+... ("text_encoder", StableDiffusionXLTextEncoderStep),
+... ("image_encoder", StableDiffusionXLVaeEncoderStep),
+... ("input", StableDiffusionXLInputStep),
+... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+... ("denoise", StableDiffusionXLDenoiseStep),
+... ("decode", StableDiffusionXLDecodeStep)
+... ])
+```
+
+Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
+
+```py
+>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
+>>> print(denoise_blocks)
+```
+
+```out
+StableDiffusionXLDenoiseStep(
+ Class: StableDiffusionXLDenoiseLoopWrapper
+
+ Description: Denoise step that iteratively denoise the latents.
+ Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
+ At each iteration, it runs blocks defined in `sub_blocks` sequencially:
+ - `StableDiffusionXLLoopBeforeDenoiser`
+ - `StableDiffusionXLLoopDenoiser`
+ - `StableDiffusionXLLoopAfterDenoiser`
+ This block supports both text2img and img2img tasks.
+
+
+ Components:
+ scheduler (`EulerDiscreteScheduler`)
+ guider (`ClassifierFreeGuidance`)
+ unet (`UNet2DConditionModel`)
+
+ Sub-Blocks:
+ [0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
+ Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
+
+ [1] denoiser (StableDiffusionXLLoopDenoiser)
+ Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
+
+ [2] after_denoiser (StableDiffusionXLLoopAfterDenoiser)
+ Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
+
+)
+```
+
+Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image.
+
+Therefore, the key differences when it comes to pipeline implementation would be:
+1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
+2. The `denoise` step (which selectively applies denoising based on the change map)
+3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version
+
+To implement differntial diffusion, we can reuse most blocks from image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).
+
+Here's a flowchart showing the pipeline structure and the changes we need to make:
+
+
+
+
+
+### Build a Working Pipeline Structure
+
+ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones:
+
+```py
+>>> # Copy existing blocks as placeholders
+>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
+... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
+... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
+...
+>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
+... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
+... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
+```
+
+`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
+
+```py
+>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
+... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+```
+
+Now we can put together our differential diffusion pipeline.
+
+```py
+>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
+>>>
+>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
+>>> print(dd_blocks)
+>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
+```
+
+### Set up an example
+
+ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple example so we can run the pipeline as we build it. diff-diff use same model checkpoints as SDXL so we can fetch the models from a regular SDXL repo.
+
+```py
+>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+>>> dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
+>>> dd_pipeline.to("cuda")
+```
+
+We will use this example script:
+
+```py
+>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+>>>
+>>> prompt = "a green pear"
+>>> negative_prompt = "blurry"
+>>>
+>>> image = dd_pipeline(
+... prompt=prompt,
+... negative_prompt=negative_prompt,
+... num_inference_steps=25,
+... diffdiff_map=mask,
+... image=image,
+... output="images"
+... )[0]
+>>>
+>>> image.save("diffdiff_out.png")
+```
+
+If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`.
+and you would get the same result as the original img2img pipeline.
+
+### implement your custom logic and test incrementally
+
+Let's modify the pipeline so that we can get expected result with this example script.
+
+We'll start with the `prepare_latents` step. The main changes are:
+- Requires a new user input `diffdiff_map`
+- Requires new component `mask_processor` to process the `diffdiff_map`
+- Requires new intermediate inputs:
+ - Need `timestep` instead of `latent_timestep` to precompute all the latents
+ - Need `num_inference_steps` to create the `diffdiff_masks`
+- create a new output `diffdiff_masks` and `original_latents`
+
+
+
+💡 use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline.
+
+e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.
+
+
+
+Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`.
+
+
+
+💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.
+
+
+
+Here are the key changes we made to implement differential diffusion:
+
+**1. Modified `prepare_latents` step:**
+```diff
+class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
++ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
++ InputParam("diffdiff_map", required=True),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
++ InputParam("timesteps", type_hint=torch.Tensor),
++ InputParam("num_inference_steps", type_hint=int),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
++ OutputParam("original_latents", type_hint=torch.Tensor),
++ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+
+ def __call__(self, components, state: PipelineState):
+ # ... existing logic ...
++ # Process change map and create masks
++ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
++ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
++ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
++ block_state.original_latents = block_state.latents
+```
+
+**2. Modified `before_denoiser` step:**
+```diff
+class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
+ )
+
++ @property
++ def inputs(self) -> List[Tuple[str, Any]]:
++ return [
++ InputParam("denoising_start"),
++ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam("latents", required=True, type_hint=torch.Tensor),
++ InputParam("original_latents", type_hint=torch.Tensor),
++ InputParam("diffdiff_masks", type_hint=torch.Tensor),
+ ]
+
+ def __call__(self, components, block_state, i, t):
++ # Apply differential diffusion logic
++ if i == 0 and block_state.denoising_start is None:
++ block_state.latents = block_state.original_latents[:1]
++ else:
++ block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
++ block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
+
+ # ... rest of existing logic ...
+```
+
+That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks.
+
+Now we use the process we've prepred in step2 to build the pipeline and inspect it.
+
+
+```py
+>> dd_pipeline
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ Description:
+
+
+ Components:
+ text_encoder (`CLIPTextModel`)
+ text_encoder_2 (`CLIPTextModelWithProjection`)
+ tokenizer (`CLIPTokenizer`)
+ tokenizer_2 (`CLIPTokenizer`)
+ guider (`ClassifierFreeGuidance`)
+ vae (`AutoencoderKL`)
+ image_processor (`VaeImageProcessor`)
+ scheduler (`EulerDiscreteScheduler`)
+ mask_processor (`VaeImageProcessor`)
+ unet (`UNet2DConditionModel`)
+
+ Configs:
+ force_zeros_for_empty_prompt (default: True)
+ requires_aesthetics_score (default: False)
+
+ Blocks:
+ [0] text_encoder (StableDiffusionXLTextEncoderStep)
+ Description: Text Encoder step that generate text_embeddings to guide the image generation
+
+ [1] image_encoder (StableDiffusionXLVaeEncoderStep)
+ Description: Vae Encoder step that encode the input image into a latent representation
+
+ [2] input (StableDiffusionXLInputStep)
+ Description: Input processing step that:
+ 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
+ 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
+
+ All input tensors are expected to have either batch_size=1 or match the batch_size
+ of prompt_embeds. The tensors will be duplicated across the batch dimension to
+ have a final batch_size of batch_size * num_images_per_prompt.
+
+ [3] set_timesteps (StableDiffusionXLSetTimestepsStep)
+ Description: Step that sets the scheduler's timesteps for inference
+
+ [4] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
+ Description: Step that prepares the latents for the differential diffusion generation process
+
+ [5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
+ Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
+
+ [6] denoise (SDXLDiffDiffDenoiseStep)
+ Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
+
+ [7] decode (StableDiffusionXLDecodeStep)
+ Description: Step that decodes the denoised latents into images
+
+)
+```
+
+Run the example now, you should see an apple with its right half transformed into a green pear.
+
+
+
+
+## Adding IP-adapter
+
+We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](./auto_pipeline_blocks.md)
+
+We talked about how to add IP-adapter into your workflow in the [Modular Pipeline Guide](./modular_pipeline.md). Let's just go ahead to create the IP-adapter block.
+
+```py
+>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
+>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
+```
+
+We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here).
+
+```py
+>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
+```
+
+Take a look at the new diff-diff pipeline with ip-adapter!
+
+```py
+>>> print(dd_blocks)
+```
+
+The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor`
+
+```out
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: {'ip_adapter_image'}
+ Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
+ ====================================================================================================
+
+
+ Description:
+
+
+ Components:
+ image_encoder (`CLIPVisionModelWithProjection`)
+ feature_extractor (`CLIPImageProcessor`)
+ unet (`UNet2DConditionModel`)
+ guider (`ClassifierFreeGuidance`)
+ text_encoder (`CLIPTextModel`)
+ text_encoder_2 (`CLIPTextModelWithProjection`)
+ tokenizer (`CLIPTokenizer`)
+ tokenizer_2 (`CLIPTokenizer`)
+ vae (`AutoencoderKL`)
+ image_processor (`VaeImageProcessor`)
+ scheduler (`EulerDiscreteScheduler`)
+ mask_processor (`VaeImageProcessor`)
+
+ Configs:
+ force_zeros_for_empty_prompt (default: True)
+ requires_aesthetics_score (default: False)
+
+ Blocks:
+ [0] ip_adapter (StableDiffusionXLAutoIPAdapterStep)
+ Description: Run IP Adapter step if `ip_adapter_image` is provided.
+
+ [1] text_encoder (StableDiffusionXLTextEncoderStep)
+ Description: Text Encoder step that generate text_embeddings to guide the image generation
+
+ [2] image_encoder (StableDiffusionXLVaeEncoderStep)
+ Description: Vae Encoder step that encode the input image into a latent representation
+
+ [3] input (StableDiffusionXLInputStep)
+ Description: Input processing step that:
+ 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
+ 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
+
+ All input tensors are expected to have either batch_size=1 or match the batch_size
+ of prompt_embeds. The tensors will be duplicated across the batch dimension to
+ have a final batch_size of batch_size * num_images_per_prompt.
+
+ [4] set_timesteps (StableDiffusionXLSetTimestepsStep)
+ Description: Step that sets the scheduler's timesteps for inference
+
+ [5] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
+ Description: Step that prepares the latents for the differential diffusion generation process
+
+ [6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
+ Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
+
+ [7] denoise (SDXLDiffDiffDenoiseStep)
+ Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
+
+ [8] decode (StableDiffusionXLDecodeStep)
+ Description: Step that decodes the denoised latents into images
+
+)
+```
+
+Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output.
+
+
+```py
+>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
+>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
+>>>
+>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
+>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
+>>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
+>>> dd_pipeline = dd_pipeline.to(device)
+>>>
+>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
+>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+>>>
+>>> prompt = "a green pear"
+>>> negative_prompt = "blurry"
+>>> generator = torch.Generator(device=device).manual_seed(42)
+>>>
+>>> image = dd_pipeline(
+... prompt=prompt,
+... negative_prompt=negative_prompt,
+... num_inference_steps=25,
+... generator=generator,
+... ip_adapter_image=ip_adapter_image,
+... diffdiff_map=mask,
+... image=image,
+... output="images"
+... )[0]
+```
+
+## Working with ControlNets
+
+What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are:
+1. A ControlNet input step that prepares the control condition
+2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet
+
+From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.
+
+Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.
+
+With this understanding, let's assemble the diffdiff-controlnet loop by combining the diffdiff before-denoiser step and controlnet denoiser step.
+
+```py
+>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
+... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+>>>
+>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
+>>> # print(controlnet_denoise)
+```
+
+We provide a auto controlnet input block that you can directly put into your workflow to proceess the `control_image`: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union.
+
+
+```py
+>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
+>>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
+>>> print(control_input_block)
+```
+
+```out
+StableDiffusionXLAutoControlNetInputStep(
+ Class: AutoPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: ['control_image', 'control_mode']
+ ====================================================================================================
+
+
+ Description: Controlnet Input step that prepare the controlnet input.
+ This is an auto pipeline block that works for both controlnet and controlnet_union.
+ (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
+ - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
+
+
+ Components:
+ controlnet (`ControlNetUnionModel`)
+ control_image_processor (`VaeImageProcessor`)
+
+ Sub-Blocks:
+ • controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
+ Description: step that prepares inputs for the ControlNetUnion model
+
+ • controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
+ Description: step that prepare inputs for controlnet
+
+)
+
+```
+
+Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape.
+
+```py
+>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
+>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
+>>>
+>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
+>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
+>>> dd_pipeline = dd_pipeline.to(device)
+>>>
+>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
+>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
+>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
+>>>
+>>> prompt = "a green pear"
+>>> negative_prompt = "blurry"
+>>> generator = torch.Generator(device=device).manual_seed(42)
+>>>
+>>> image = dd_pipeline(
+... prompt=prompt,
+... negative_prompt=negative_prompt,
+... num_inference_steps=25,
+... generator=generator,
+... control_image=control_image,
+... controlnet_conditioning_scale=0.5,
+... diffdiff_map=mask,
+... image=image,
+... output="images"
+... )[0]
+```
+
+Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet.
+
+
+```py
+>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
+... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
+... block_names = ["controlnet_denoise", "denoise"]
+... block_trigger_inputs = ["controlnet_cond", None]
+```
+
+`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step.
+
+
+
+ Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected.
+
+
+
+Now you can create the differential diffusion preset that works with ip-adapter & controlnet.
+
+```py
+>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
+>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
+>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
+>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
+>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
+>>>
+>>> print(DIFFDIFF_AUTO_BLOCKS)
+```
+
+to use
+
+```py
+>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
+>>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
+```
+## Creating a Modular Repo
+
+You can easily share your differential diffusion workflow on the Hub by creating a modular repo. This is one created using the code we just wrote together: https://huggingface.co/YiYiXu/modular-diffdiff
+
+To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily.
+
+```py
+dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
+```
+
+With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example to use the differential-diffusion pipeline we just created and shared.
+
+```py
+>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
+>>> import torch
+>>> from diffusers.utils import load_image
+>>>
+>>> repo_id = "YiYiXu/modular-diffdiff-0704"
+>>>
+>>> components = ComponentsManager()
+>>>
+>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
+>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
+>>> components.enable_auto_cpu_offload()
+```
+
+see more usage example on model card.
+
+## deploy a mellon node
+
+[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon]
diff --git a/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
new file mode 100644
index 0000000000..e95cdc7163
--- /dev/null
+++ b/docs/source/en/modular_diffusers/loop_sequential_pipeline_blocks.md
@@ -0,0 +1,194 @@
+
+
+# LoopSequentialPipelineBlocks
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+`LoopSequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. It is a multi-block that composes other blocks together in a loop, creating iterative workflows where blocks run multiple times with evolving state. It's particularly useful for denoising loops requiring repeated execution of the same blocks.
+
+
+
+Other types of multi-blocks include [SequentialPipelineBlocks](./sequential_pipeline_blocks.md) (for linear workflows) and [AutoPipelineBlocks](./auto_pipeline_blocks.md) (for conditional block selection). For information on creating individual blocks, see the [PipelineBlock guide](./pipeline_block.md).
+
+Additionally, like all `ModularPipelineBlocks`, `LoopSequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
+
+
+
+You could create a loop using `PipelineBlock` like this:
+
+```python
+class DenoiseLoop(PipelineBlock):
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ for t in range(block_state.num_inference_steps):
+ # ... loop logic here
+ pass
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+But in this tutorial, we will focus on how to use `LoopSequentialPipelineBlocks` to create a "composable" denoising loop where you can add or remove blocks within the loop or reuse the same loop structure with different block combinations.
+
+It involves two parts: a **loop wrapper** and **loop blocks**
+
+* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar.
+
+* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper.
+ - they run sequentially for each iteration of the loop
+ - they receive the current iteration index as an additional parameter
+ - they share the same block_state throughout the entire loop
+
+Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations.
+
+We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps:
+1. defining the loop wrapper class
+2. creating the loop blocks
+3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance
+
+**Step 1: Define the Loop Wrapper**
+
+To create a `LoopSequentialPipelineBlocks` class, you need to define:
+
+* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`)
+* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`)
+* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`)
+* `__call__` method: Defines the loop structure and iteration logic
+
+Here is an example of a loop wrapper:
+
+```py
+import torch
+from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam
+
+class LoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "test"
+ @property
+ def description(self):
+ return "I'm a loop!!"
+ @property
+ def loop_inputs(self):
+ return [InputParam(name="num_steps")]
+ @torch.no_grad()
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ # Loop structure - can be customized to your needs
+ for i in range(block_state.num_steps):
+ # loop_step executes all registered blocks in sequence
+ components, block_state = self.loop_step(components, block_state, i=i)
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+**Step 2: Create Loop Blocks**
+
+Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently:
+* It receives the iteration variable (e.g., `i`) passed by the loop wrapper
+* It works directly with `block_state` instead of pipeline state
+* No need to call `self.get_block_state()` or `self.set_block_state()`
+
+```py
+class LoopBlock(PipelineBlock):
+ # this is used to identify the model family, we won't worry about it in this example
+ model_name = "test"
+ @property
+ def inputs(self):
+ return [InputParam(name="x")]
+ @property
+ def intermediate_outputs(self):
+ # outputs produced by this block
+ return [OutputParam(name="x")]
+ @property
+ def description(self):
+ return "I'm a block used inside the `LoopWrapper` class"
+ def __call__(self, components, block_state, i: int):
+ block_state.x += 1
+ return components, block_state
+```
+
+**Step 3: Combine Everything**
+
+Finally, assemble your loop by adding the block(s) to the wrapper:
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
+```
+
+Now you've created a loop with one step:
+
+```py
+>>> loop
+LoopWrapper(
+ Class: LoopSequentialPipelineBlocks
+
+ Description: I'm a loop!!
+
+ Sub-Blocks:
+ [0] block1 (LoopBlock)
+ Description: I'm a block used inside the `LoopWrapper` class
+
+)
+```
+
+It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop.
+
+```py
+>>> print(loop.doc)
+class LoopWrapper
+
+ I'm a loop!!
+
+ Inputs:
+
+ x (`None`, *optional*):
+
+ num_steps (`None`, *optional*):
+
+ Outputs:
+
+ x (`None`):
+```
+
+**Running the Loop:**
+
+```py
+# run the loop
+loop_pipeline = loop.init_pipeline()
+x = loop_pipeline(num_steps=10, x=0, output="x")
+assert x == 10
+```
+
+**Adding Multiple Blocks:**
+
+We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration:
+
+```py
+loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
+loop_pipeline = loop.init_pipeline()
+x = loop_pipeline(num_steps=10, x=0, output="x")
+assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20
+```
+
+**Key Differences from SequentialPipelineBlocks:**
+
+The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself.
+
+The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice:
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
+StableDiffusionXLDenoiseStep()
+```
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/modular_diffusers_states.md b/docs/source/en/modular_diffusers/modular_diffusers_states.md
new file mode 100644
index 0000000000..744089fcf6
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_diffusers_states.md
@@ -0,0 +1,59 @@
+
+
+# PipelineState and BlockState
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+In Modular Diffusers, `PipelineState` and `BlockState` are the core data structures that enable blocks to communicate and share data. The concept is fundamental to understand how blocks interact with each other and the pipeline system.
+
+In the modular diffusers system, `PipelineState` acts as the global state container that all pipeline blocks operate on. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
+
+A `PipelineState` consists of two distinct states:
+
+- **The immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.
+
+- **The mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.
+
+Here's an example of what a `PipelineState` looks like:
+
+```py
+PipelineState(
+ inputs={
+ 'prompt': 'a cat'
+ 'guidance_scale': 7.0
+ 'num_inference_steps': 25
+ },
+ intermediates={
+ 'prompt_embeds': Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
+ 'negative_prompt_embeds': None
+ },
+)
+```
+
+Each pipeline blocks define what parts of that state they can read from and write to through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties. At run time, they gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.
+
+For example, if a block defines an input `image`, inside the block's `__call__` method, the `BlockState` would contain:
+
+```py
+BlockState(
+ image:
+)
+```
+
+You can access the variables directly as attributes: `block_state.image`.
+
+We will explore more on how blocks interact with pipeline state through their `inputs`, `intermediate_inputs`, and `intermediate_outputs` properties, see the [PipelineBlock guide](./pipeline_block.md).
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/modular_pipeline.md b/docs/source/en/modular_diffusers/modular_pipeline.md
new file mode 100644
index 0000000000..55182b921f
--- /dev/null
+++ b/docs/source/en/modular_diffusers/modular_pipeline.md
@@ -0,0 +1,1237 @@
+
+
+# ModularPipeline
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+`ModularPipeline` is the main interface for end users to run pipelines in Modular Diffusers. It takes pipeline blocks and converts them into a runnable pipeline that can load models and execute the computation steps.
+
+In this guide, we will focus on how to build pipelines using the blocks we officially support at diffusers 🧨. We'll cover how to use predefined blocks and convert them into a `ModularPipeline` for execution.
+
+
+
+This guide shows you how to use predefined blocks. If you want to learn how to create your own pipeline blocks, see the [PipelineBlock guide](pipeline_block.md) for creating individual blocks, and the multi-block guides for connecting them together:
+- [SequentialPipelineBlocks](sequential_pipeline_blocks.md) (for linear workflows)
+- [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows)
+- [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional workflows)
+
+For information on how data flows through pipelines, see the [PipelineState and BlockState guide](modular_diffusers_states.md).
+
+
+
+
+## Create ModularPipelineBlocks
+
+In Modular Diffusers system, you build pipelines using Pipeline blocks. Pipeline Blocks are fundamental building blocks - they define what components, inputs/outputs, and computation logics are needed. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. But they are just definitions and don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. We'll first learn how to create predefined blocks here before talking about how to run them using `ModularPipeline`.
+
+All pipeline blocks inherit from the base class `ModularPipelineBlocks`, including:
+
+- [`PipelineBlock`]: The most granular block - you define the input/output/components requirements and computation logic.
+- [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block.
+- [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models.
+- [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs.
+
+It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLTextEncoderStep
+
+text_encoder_block = StableDiffusionXLTextEncoderStep()
+```
+
+This is a single `PipelineBlock`. You'll see that this text encoder block uses 2 text_encoders, 2 tokenizers as well as a guider component. It takes user inputs such as `prompt` and `negative_prompt`, and return text embeddings outputs such as `prompt_embeds` and `negative_prompt_embeds`.
+
+```py
+>>> text_encoder_block
+StableDiffusionXLTextEncoderStep(
+ Class: PipelineBlock
+ Description: Text Encoder step that generate text_embeddings to guide the image generation
+ Components:
+ text_encoder (`CLIPTextModel`)
+ text_encoder_2 (`CLIPTextModelWithProjection`)
+ tokenizer (`CLIPTokenizer`)
+ tokenizer_2 (`CLIPTokenizer`)
+ guider (`ClassifierFreeGuidance`)
+ Configs:
+ force_zeros_for_empty_prompt (default: True)
+ Inputs:
+ prompt=None, prompt_2=None, negative_prompt=None, negative_prompt_2=None, cross_attention_kwargs=None, clip_skip=None
+ Intermediates:
+ - outputs: prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+)
+```
+
+More commonly, you need multiple blocks to build your workflow. You can create a `SequentialPipelineBlocks` using block class presets from 🧨 Diffusers. `TEXT2IMAGE_BLOCKS` is a dict containing all the blocks needed for text-to-image generation.
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+```
+
+This creates a `SequentialPipelineBlocks`. Unlike the `text_encoder_block` we saw earlier, this is a multi-block and its `sub_blocks` attribute contains a list of other blocks (text_encoder, input, set_timesteps, prepare_latents, prepare_added_con, denoise, decode). Its requirements for components, inputs, and intermediate inputs are combined from these blocks that compose it. At runtime, it executes its sub-blocks sequentially and passes the pipeline state from one block to another.
+
+```py
+>>> t2i_blocks
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ Description:
+
+
+ Components:
+ text_encoder (`CLIPTextModel`)
+ text_encoder_2 (`CLIPTextModelWithProjection`)
+ tokenizer (`CLIPTokenizer`)
+ tokenizer_2 (`CLIPTokenizer`)
+ guider (`ClassifierFreeGuidance`)
+ scheduler (`EulerDiscreteScheduler`)
+ unet (`UNet2DConditionModel`)
+ vae (`AutoencoderKL`)
+ image_processor (`VaeImageProcessor`)
+
+ Configs:
+ force_zeros_for_empty_prompt (default: True)
+
+ Sub-Blocks:
+ [0] text_encoder (StableDiffusionXLTextEncoderStep)
+ Description: Text Encoder step that generate text_embeddings to guide the image generation
+
+ [1] input (StableDiffusionXLInputStep)
+ Description: Input processing step that:
+ 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
+ 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
+
+ All input tensors are expected to have either batch_size=1 or match the batch_size
+ of prompt_embeds. The tensors will be duplicated across the batch dimension to
+ have a final batch_size of batch_size * num_images_per_prompt.
+
+ [2] set_timesteps (StableDiffusionXLSetTimestepsStep)
+ Description: Step that sets the scheduler's timesteps for inference
+
+ [3] prepare_latents (StableDiffusionXLPrepareLatentsStep)
+ Description: Prepare latents step that prepares the latents for the text-to-image generation process
+
+ [4] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
+ Description: Step that prepares the additional conditioning for the text-to-image generation process
+
+ [5] denoise (StableDiffusionXLDenoiseStep)
+ Description: Denoise step that iteratively denoise the latents.
+ Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
+ At each iteration, it runs blocks defined in `sub_blocks` sequencially:
+ - `StableDiffusionXLLoopBeforeDenoiser`
+ - `StableDiffusionXLLoopDenoiser`
+ - `StableDiffusionXLLoopAfterDenoiser`
+ This block supports both text2img and img2img tasks.
+
+ [6] decode (StableDiffusionXLDecodeStep)
+ Description: Step that decodes the denoised latents into images
+
+)
+```
+
+This is the block classes preset (`TEXT2IMAGE_BLOCKS`) we used: It is just a dictionary that maps names to ModularPipelineBlocks classes
+
+```py
+>>> TEXT2IMAGE_BLOCKS
+InsertableDict([
+ 0: ('text_encoder', ),
+ 1: ('input', ),
+ 2: ('set_timesteps', ),
+ 3: ('prepare_latents', ),
+ 4: ('prepare_add_cond', ),
+ 5: ('denoise', ),
+ 6: ('decode', )
+])
+```
+
+When we create a `SequentialPipelineBlocks` from this preset, it instantiates each block class into actual block objects. Its `sub_blocks` attribute now contains these instantiated objects:
+
+```py
+>>> t2i_blocks.sub_blocks
+InsertableDict([
+ 0: ('text_encoder', ),
+ 1: ('input', ),
+ 2: ('set_timesteps', ),
+ 3: ('prepare_latents', ),
+ 4: ('prepare_add_cond', ),
+ 5: ('denoise', ),
+ 6: ('decode', )
+])
+```
+
+Note that both the block classes preset and the `sub_blocks` attribute are `InsertableDict` objects. This is a custom dictionary that extends `OrderedDict` with the ability to insert items at specific positions. You can perform all standard dictionary operations (get, set, delete) plus insert items at any index, which is particularly useful for reordering or inserting blocks in the middle of a pipeline.
+
+**Add a block:**
+```py
+# BLOCKS is dict of block classes, you need to add class to it
+BLOCKS.insert("block_name", BlockClass, index)
+# sub_blocks attribute contains instance, add a block instance to the attribute
+t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
+```
+
+**Remove a block:**
+```py
+# remove a block class from preset
+BLOCKS.pop("text_encoder")
+# split out a block instance on its own
+text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
+```
+
+**Swap block:**
+```py
+# Replace block class in preset
+BLOCKS["prepare_latents"] = CustomPrepareLatents
+# Replace in sub_blocks attribute using an block instance
+t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
+```
+
+This means you can mix-and-match blocks in very flexible ways. Let's see some real examples:
+
+**Example 1: Adding IP-Adapter to the Block Classes Preset**
+Let's make a new block classes preset by insert IP-Adapter at index 0 (before the text_encoder block), and create a text-to-image pipeline with IP-Adapter support:
+
+```py
+from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
+CUSTOM_BLOCKS = TEXT2IMAGE_BLOCKS.copy()
+# CUSTOM_BLOCKS is now a preset including ip_adapter
+CUSTOM_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
+# create a blocks isntance from the preset
+custom_blocks = SequentialPipelineBlocks.from_blocks_dict(CUSTOM_BLOCKS)
+```
+
+**Example 2: Extracting a block from a multi-block**
+You can extract a block instance from the multi-block to use it independently. A common pattern is to use text_encoder to process prompts once, then reuse the text embeddings outputs to generate multiple images with different settings (schedulers, seeds, inference steps). We can do this by simply extracting the text_encoder block from the pipeline.
+
+```py
+# this gives you StableDiffusionXLTextEncoderStep()
+>>> text_encoder_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
+>>> text_encoder_blocks
+```
+
+The multi-block now has fewer components and no longer has the `text_encoder` block. If you check its docstring `t2i_blocks.doc`, you will see that it no longer accepts `prompt` as input - you will need to pass the embeddings instead.
+
+```py
+>>> t2i_blocks
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ Description:
+
+ Components:
+ scheduler (`EulerDiscreteScheduler`)
+ guider (`ClassifierFreeGuidance`)
+ unet (`UNet2DConditionModel`)
+ vae (`AutoencoderKL`)
+ image_processor (`VaeImageProcessor`)
+
+ Blocks:
+ [0] input (StableDiffusionXLInputStep)
+ Description: Input processing step that:
+ 1. Determines `batch_size` and `dtype` based on `prompt_embeds`
+ 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
+
+ All input tensors are expected to have either batch_size=1 or match the batch_size
+ of prompt_embeds. The tensors will be duplicated across the batch dimension to
+ have a final batch_size of batch_size * num_images_per_prompt.
+
+ [1] set_timesteps (StableDiffusionXLSetTimestepsStep)
+ Description: Step that sets the scheduler's timesteps for inference
+
+ [2] prepare_latents (StableDiffusionXLPrepareLatentsStep)
+ Description: Prepare latents step that prepares the latents for the text-to-image generation process
+
+ [3] prepare_add_cond (StableDiffusionXLPrepareAdditionalConditioningStep)
+ Description: Step that prepares the additional conditioning for the text-to-image generation process
+
+ [4] denoise (StableDiffusionXLDenoiseLoop)
+ Description: Denoise step that iteratively denoise the latents.
+ Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
+ At each iteration, it runs blocks defined in `blocks` sequencially:
+ - `StableDiffusionXLLoopBeforeDenoiser`
+ - `StableDiffusionXLLoopDenoiser`
+ - `StableDiffusionXLLoopAfterDenoiser`
+
+
+ [5] decode (StableDiffusionXLDecodeStep)
+ Description: Step that decodes the denoised latents into images
+
+)
+```
+
+
+
+💡 You can find all the block classes presets we support for each model in `ALL_BLOCKS`.
+
+```py
+# For Stable Diffusion XL
+from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
+ALL_BLOCKS
+# For other models...
+from diffusers.modular_pipelines. import ALL_BLOCKS
+```
+
+Each model provides a dictionary that maps all supported tasks/techniques to their corresponding block classes presets. For SDXL, it is
+
+```py
+ALL_BLOCKS = {
+ "text2img": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "ip_adapter": IP_ADAPTER_BLOCKS,
+ "auto": AUTO_BLOCKS,
+}
+```
+
+
+
+This covers the essentials of pipeline blocks! Like we have already mentioned, **pipeline blocks are not runnable by themselves**. They are essentially **"definitions"** - they define the specifications and computational steps for a pipeline, but they do not contain any model states. To actually run them, you need to convert them into a `ModularPipeline` object.
+
+
+## Modular Repo
+
+To convert blocks into a runnable pipeline, you may need a repository if your blocks contain **pretrained components** (models with checkpoints that need to be loaded from the Hub). Pipeline blocks define what components they need (like a UNet, text encoder, etc.), as well as how to create them: components can be either created using **from_pretrained** method (with checkpoints) or **from_config** (initialized from scratch with default configuration, usually stateless like a guider or scheduler).
+
+If your pipeline contains **pretrained components**, you typically need to use a repository to provide the loading specifications and metadata.
+
+`ModularPipeline` works specifically with modular repositories, which offer more flexibility in component loading compared to traditional repositories. You can find an example modular repo [here](https://huggingface.co/YiYiXu/modular-diffdiff).
+
+A `DiffusionPipeline` defines `model_index.json` to configure its components. However, repositories for Modular Diffusers work with `modular_model_index.json`. Let's walk through the differences here.
+
+In standard `model_index.json`, each component entry is a `(library, class)` tuple:
+```py
+"text_encoder": [
+ "transformers",
+ "CLIPTextModel"
+],
+```
+
+In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs_dict)`
+
+- `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded)
+- `loading_specs_dict`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`.
+
+```py
+"text_encoder": [
+ null, # library of actual loaded component (same as in model_index.json)
+ null, # class of actual loaded componenet (same as in model_index.json)
+ { # loading specs map (unique to modular_model_index.json)
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0", # can be a different repo
+ "revision": null,
+ "subfolder": "text_encoder",
+ "type_hint": [ # (library, class) for the expected component
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "variant": null
+ }
+],
+```
+
+Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories.
+
+
+## Creating a `ModularPipeline` from `ModularPipelineBlocks`
+
+Each `ModularPipelineBlocks` has an `init_pipeline` method that can initialize a `ModularPipeline` object based on its component and configuration specifications.
+
+Let's convert our `t2i_blocks` (which we created earlier) into a runnable `ModularPipeline`. We'll use a `ComponentsManager` to handle device placement, memory management, and component reuse automatically:
+
+```py
+# We already have this from earlier
+t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+# Now convert it to a ModularPipeline
+from diffusers import ComponentsManager
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+components = ComponentsManager()
+t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
+```
+
+
+
+💡 **ComponentsManager** is the model registry and management system in diffusers, it track all the models in one place and let you add, remove and reuse them across different workflows in most efficient way. Without it, you'd need to manually manage GPU memory, device placement, and component sharing between workflows. See the [Components Manager guide](components_manager.md) for detailed information.
+
+
+
+The `init_pipeline()` method creates a ModularPipeline and loads component specifications from the repository's `modular_model_index.json` file, but doesn't load the actual models yet.
+
+
+## Creating a `ModularPipeline` with `from_pretrained`
+
+You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo:
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+components = ComponentsManager()
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
+```
+
+Loading custom code is also supported:
+
+```py
+from diffusers import ModularPipeline, ComponentsManager
+components = ComponentsManager()
+modular_repo_id = "YiYiXu/modular-diffdiff-0704"
+diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
+```
+
+This modular repository contains custom code. The folder contains these files:
+
+```
+modular-diffdiff-0704/
+├── block.py # Custom pipeline blocks implementation
+├── config.json # Pipeline configuration and auto_map
+└── modular_model_index.json # Component loading specifications
+```
+
+The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation:
+
+```json
+{
+ "_class_name": "DiffDiffBlocks",
+ "auto_map": {
+ "ModularPipelineBlocks": "block.DiffDiffBlocks"
+ }
+}
+```
+
+The `auto_map` tells the pipeline where to find the custom blocks definition - in this case, it's looking for `DiffDiffBlocks` in the `block.py` file. The actual `DiffDiffBlocks` class is defined in [`block.py`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/block.py) within the repository.
+
+When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library.
+
+## Loading components into a `ModularPipeline`
+
+Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=..,)`:
+
+```py
+# This will load ALL the expected components into pipeline
+import torch
+t2i_pipeline.load_default_components(torch_dtype=torch.float16)
+t2i_pipeline.to("cuda")
+```
+
+All expected components are now loaded into the pipeline. You can also partially load specific components using the `names` argument. For example, to only load unet and vae:
+
+```py
+>>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
+```
+
+You can inspect the pipeline's loading status by simply printing the pipeline itself. It helps you understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. Let's print out the `t2i_pipeline`:
+
+```py
+>>> t2i_pipeline
+StableDiffusionXLModularPipeline {
+ "_blocks_class_name": "SequentialPipelineBlocks",
+ "_class_name": "StableDiffusionXLModularPipeline",
+ "_diffusers_version": "0.35.0.dev0",
+ "force_zeros_for_empty_prompt": true,
+ "scheduler": [
+ null,
+ null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "scheduler",
+ "type_hint": [
+ "diffusers",
+ "EulerDiscreteScheduler"
+ ],
+ "variant": null
+ }
+ ],
+ "text_encoder": [
+ null,
+ null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "text_encoder",
+ "type_hint": [
+ "transformers",
+ "CLIPTextModel"
+ ],
+ "variant": null
+ }
+ ],
+ "text_encoder_2": [
+ null,
+ null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "text_encoder_2",
+ "type_hint": [
+ "transformers",
+ "CLIPTextModelWithProjection"
+ ],
+ "variant": null
+ }
+ ],
+ "tokenizer": [
+ null,
+ null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "tokenizer",
+ "type_hint": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "variant": null
+ }
+ ],
+ "tokenizer_2": [
+ null,
+ null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "tokenizer_2",
+ "type_hint": [
+ "transformers",
+ "CLIPTokenizer"
+ ],
+ "variant": null
+ }
+ ],
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel",
+ {
+ "repo": "RunDiffusion/Juggernaut-XL-v9",
+ "revision": null,
+ "subfolder": "unet",
+ "type_hint": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "variant": "fp16"
+ }
+ ],
+ "vae": [
+ "diffusers",
+ "AutoencoderKL",
+ {
+ "repo": "madebyollin/sdxl-vae-fp16-fix",
+ "revision": null,
+ "subfolder": null,
+ "type_hint": [
+ "diffusers",
+ "AutoencoderKL"
+ ],
+ "variant": null
+ }
+ ]
+}
+```
+
+You can see all the **pretrained components** that will be loaded using `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`:
+
+- **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet.
+- **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.)
+
+In this example:
+- **Loaded components**: `vae` and `unet` (their `library` and `class` fields show the actual loaded models)
+- **Not loaded yet**: `scheduler`, `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2` (their `library` and `class` fields are `null`, but you can see their loading specs to know where they'll be loaded from when you call `load_components()`)
+
+You're looking at essentailly the pipeline's config dict that's synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements.
+
+For example, if your pipeline needs a `text_encoder` component, it will include the loading spec for `text_encoder` from the modular repo during the `init_pipeline`. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't be included even if it exists in the modular repo.
+
+There are also a few properties that can provide a quick summary of component loading status:
+
+```py
+# All components expected by the pipeline
+>>> t2i_pipeline.component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
+
+# Components that are not loaded yet (will be loaded with from_pretrained)
+>>> t2i_pipeline.null_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
+
+# Components that will be loaded from pretrained models
+>>> t2i_pipeline.pretrained_component_names
+['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
+
+# Components that are created with default config (no repo needed)
+>>> t2i_pipeline.config_component_names
+['guider', 'image_processor']
+```
+
+From config components (like `guider` and `image_processor`) are not included in the pipeline output above because they don't need loading specs - they're already initialized during pipeline creation. You can see this because they're not listed in `null_component_names`.
+
+## Modifying Loading Specs
+
+When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. You can change where components are loaded from by modifying the `modular_model_index.json` in the repository. Just find the file on the Hub and click edit - you can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc.
+
+```py
+# Original spec in modular_model_index.json
+"unet": [
+ null, null,
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+
+# Modified spec - changed repo, subfolder, and variant
+"unet": [
+ null, null,
+ {
+ "repo": "RunDiffusion/Juggernaut-XL-v9",
+ "subfolder": "unet",
+ "variant": "fp16"
+ }
+]
+```
+
+Now if you create a pipeline using the same blocks and updated repository, it will by default load from the new repository.
+
+```py
+pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
+pipeline.load_components(names="unet")
+```
+
+
+## Updating components in a `ModularPipeline`
+
+Similar to `DiffusionPipeline`, you can load components separately to replace the default ones in the pipeline. In Modular Diffusers, the approach depends on the component type:
+
+- **Pretrained components** (`default_creation_method='from_pretrained'`): Must use `ComponentSpec` to load them to update the existing one.
+- **Config components** (`default_creation_method='from_config'`): These are components that don't need loading specs - they're created during pipeline initialization with default config. To update them, you can either pass the object directly or pass a ComponentSpec directly.
+
+
+
+💡 **Component Type Changes**: The component type (pretrained vs config-based) can change when you update components. These types are initially defined in pipeline blocks' `expected_components` field using `ComponentSpec` with `default_creation_method`. See the [Customizing Guidance Techniques](#customizing-guidance-techniques) section for examples of how this works in practice.
+
+
+
+`ComponentSpec` defines how to create or load components and can actually create them using its `create()` method (for ConfigMixin objects) or `load()` method (wrapper around `from_pretrained()`). When a component is loaded with a ComponentSpec, it gets tagged with a unique ID that encodes its creation parameters, allowing you to always extract the original specification using `ComponentSpec.from_component()`.
+
+Now let's look at how to update pretrained components in practice:
+
+So instead of
+
+```py
+from diffusers import UNet2DConditionModel
+import torch
+unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16", torch_dtype=torch.float16)
+```
+You should load your model like this
+
+```py
+from diffusers import ComponentSpec, UNet2DConditionModel
+unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
+unet2 = unet_spec.load(torch_dtype=torch.float16)
+```
+
+The key difference is that the second unet retains its loading specs, so you can extract the spec and recreate the unet:
+
+```py
+# component -> spec
+>>> spec = ComponentSpec.from_component("unet", unet2)
+>>> spec
+ComponentSpec(name='unet', type_hint=, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
+# spec -> component
+>>> unet2_recreatd = spec.load(torch_dtype=torch.float16)
+```
+
+To replace the unet in the pipeline
+
+```
+t2i_pipeline.update_components(unet=unet2)
+```
+
+Not only is the `unet` component swapped, but its loading specs are also updated from "RunDiffusion/Juggernaut-XL-v9" to "stabilityai/stable-diffusion-xl-base-1.0" in pipeline config. This means that if you save the pipeline now and load it back with `from_pretrained`, the new pipeline will by default load the SDXL original unet.
+
+```
+>>> t2i_pipeline
+StableDiffusionXLModularPipeline {
+ ...
+ "unet": [
+ "diffusers",
+ "UNet2DConditionModel",
+ {
+ "repo": "stabilityai/stable-diffusion-xl-base-1.0",
+ "revision": null,
+ "subfolder": "unet",
+ "type_hint": [
+ "diffusers",
+ "UNet2DConditionModel"
+ ],
+ "variant": "fp16"
+ }
+ ],
+ ...
+}
+```
+
+
+💡 **Modifying Component Specs**: You can get a copy of the current component spec from the pipeline using `get_component_spec()`. This makes it easy to modify the spec and updating components.
+
+```py
+>>> unet_spec = t2i_pipeline.get_component_spec("unet")
+>>> unet_spec
+ComponentSpec(
+ name='unet',
+ type_hint=,
+ repo='RunDiffusion/Juggernaut-XL-v9',
+ subfolder='unet',
+ variant='fp16',
+ default_creation_method='from_pretrained'
+)
+
+# Modify the spec to load from a different repository
+>>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0"
+
+# Load the component with the modified spec
+>>> unet = unet_spec.load(torch_dtype=torch.float16)
+```
+
+
+
+## Customizing Guidance Techniques
+
+Guiders are implementations of different [classifier-free guidance](https://huggingface.co/papers/2207.12598) techniques that can be applied during the denoising process to improve generation quality, control, and adherence to prompts. They work by steering the model predictions towards desired directions and away from undesired directions. In diffusers, guiders are implemented as subclasses of `BaseGuidance`. They can easily be integrated into modular pipelines and provide a flexible way to enhance generation quality without modifying the underlying diffusion models.
+
+**ClassifierFreeGuidance (CFG)** is the first and most common guidance technique, used in all our standard pipelines. We also offer many other guidance techniques from the latest research in this area - **PerturbedAttentionGuidance (PAG)**, **SkipLayerGuidance (SLG)**, **SmoothedEnergyGuidance (SEG)**, and others that can provide better results for specific use cases.
+
+This section demonstrates how to use guiders using the component updating methods we just learned. Since `BaseGuidance` components are stateless (similar to schedulers), they are typically created with default configurations during pipeline initialization using `default_creation_method='from_config'`. This means they don't require loading specs from the repository - you won't see guider listed in `modular_model_index.json` files.
+
+Let's take a look at the default guider configuration:
+
+```py
+>>> t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 7.5), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['start', 'guidance_rescale', 'stop', 'use_original_formulation'])]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+As you can see, the guider is configured to use `ClassifierFreeGuidance` with default parameters and `default_creation_method='from_config'`, meaning it's created during pipeline initialization rather than loaded from a repository. Let's verify this, here we run `init_pipeline()` without a modular repo, and there it is, a guider with the default configuration we just saw
+
+
+```py
+>>> pipeline = t2i_blocks.init_pipeline()
+>>> pipeline.guider
+ClassifierFreeGuidance {
+ "_class_name": "ClassifierFreeGuidance",
+ "_diffusers_version": "0.35.0.dev0",
+ "guidance_rescale": 0.0,
+ "guidance_scale": 7.5,
+ "start": 0.0,
+ "stop": 1.0,
+ "use_original_formulation": false
+}
+```
+
+#### Modify Parameters of the Same Guider Type
+
+To change parameters of the same guider type (e.g., adjusting the `guidance_scale` for CFG), you have two options:
+
+**Option 1: Use ComponentSpec.create() method**
+
+You just need to pass the parameter with the new value to override the default one.
+
+```python
+>>> guider_spec = t2i_pipeline.get_component_spec("guider")
+>>> guider = guider_spec.create(guidance_scale=10)
+>>> t2i_pipeline.update_components(guider=guider)
+```
+
+**Option 2: Pass ComponentSpec directly**
+
+Update the spec directly and pass it to `update_components()`.
+
+```python
+>>> guider_spec = t2i_pipeline.get_component_spec("guider")
+>>> guider_spec.config["guidance_scale"] = 10
+>>> t2i_pipeline.update_components(guider=guider_spec)
+```
+
+Both approaches produce the same result:
+```python
+>>> t2i_pipeline.guider
+ClassifierFreeGuidance {
+ "_class_name": "ClassifierFreeGuidance",
+ "_diffusers_version": "0.35.0.dev0",
+ "guidance_rescale": 0.0,
+ "guidance_scale": 10,
+ "start": 0.0,
+ "stop": 1.0,
+ "use_original_formulation": false
+}
+```
+
+#### Switch to a Different Guider Type
+
+Switching between guidance techniques is as simple as passing a guider object of that technique:
+
+```py
+from diffusers import LayerSkipConfig, PerturbedAttentionGuidance
+config = LayerSkipConfig(indices=[2, 9], fqn="mid_block.attentions.0.transformer_blocks", skip_attention=False, skip_attention_scores=True, skip_ff=False)
+guider = PerturbedAttentionGuidance(
+ guidance_scale=5.0, perturbed_guidance_scale=2.5, perturbed_guidance_config=config
+)
+t2i_pipeline.update_components(guider=guider)
+```
+
+Note that you will get a warning about changing the guider type, which is expected:
+
+```
+ModularPipeline.update_components: adding guider with new type: PerturbedAttentionGuidance, previous type: ClassifierFreeGuidance
+```
+
+
+
+- For `from_config` components (like guiders, schedulers): You can pass an object of required type OR pass a ComponentSpec directly (which calls `create()` under the hood)
+- For `from_pretrained` components (like models): You must use ComponentSpec to ensure proper tagging and loading
+
+
+
+Let's verify that the guider has been updated:
+
+```py
+>>> t2i_pipeline.guider
+PerturbedAttentionGuidance {
+ "_class_name": "PerturbedAttentionGuidance",
+ "_diffusers_version": "0.35.0.dev0",
+ "guidance_rescale": 0.0,
+ "guidance_scale": 5.0,
+ "perturbed_guidance_config": {
+ "dropout": 1.0,
+ "fqn": "mid_block.attentions.0.transformer_blocks",
+ "indices": [
+ 2,
+ 9
+ ],
+ "skip_attention": false,
+ "skip_attention_scores": true,
+ "skip_ff": false
+ },
+ "perturbed_guidance_layers": null,
+ "perturbed_guidance_scale": 2.5,
+ "perturbed_guidance_start": 0.01,
+ "perturbed_guidance_stop": 0.2,
+ "start": 0.0,
+ "stop": 1.0,
+ "use_original_formulation": false
+}
+
+```
+
+The component spec has also been updated to reflect the new guider type:
+
+```py
+>>> t2i_pipeline.get_component_spec("guider")
+ComponentSpec(name='guider', type_hint=, description=None, config=FrozenDict([('guidance_scale', 5.0), ('perturbed_guidance_scale', 2.5), ('perturbed_guidance_start', 0.01), ('perturbed_guidance_stop', 0.2), ('perturbed_guidance_layers', None), ('perturbed_guidance_config', LayerSkipConfig(indices=[2, 9], fqn='mid_block.attentions.0.transformer_blocks', skip_attention=False, skip_attention_scores=True, skip_ff=False, dropout=1.0)), ('guidance_rescale', 0.0), ('use_original_formulation', False), ('start', 0.0), ('stop', 1.0), ('_use_default_values', ['perturbed_guidance_start', 'use_original_formulation', 'perturbed_guidance_layers', 'stop', 'start', 'guidance_rescale', 'perturbed_guidance_stop']), ('_class_name', 'PerturbedAttentionGuidance'), ('_diffusers_version', '0.35.0.dev0')]), repo=None, subfolder=None, variant=None, revision=None, default_creation_method='from_config')
+```
+
+The "guider" is still a `from_config` component: is still not included in the pipeline config and will not be saved into the `modular_model_index.json`.
+
+```py
+>>> assert "guider" not in t2i_pipeline.config
+```
+
+However, you can change it to a `from_pretrained` component, which allows you to upload your customized guider to the Hub and load it into your pipeline.
+
+#### Loading Custom Guiders from Hub
+
+If you already have a guider saved on the Hub and a `modular_model_index.json` with the loading spec for that guider, it will automatically be changed to a `from_pretrained` component during pipeline initialization.
+
+For example, this `modular_model_index.json` includes loading specs for the guider:
+
+```json
+{
+ "guider": [
+ null,
+ null,
+ {
+ "repo": "YiYiXu/modular-loader-t2i-guider",
+ "revision": null,
+ "subfolder": "pag_guider",
+ "type_hint": [
+ "diffusers",
+ "PerturbedAttentionGuidance"
+ ],
+ "variant": null
+ }
+ ]
+}
+```
+
+When you use this repository to create a pipeline with the same blocks (that originally configured guider as a `from_config` component), the guider becomes a `from_pretrained` component. This means it doesn't get created during initialization, and after you call `load_default_components()`, it loads based on the spec - resulting in the PAG guider instead of the default CFG.
+
+```py
+t2i_pipeline = t2i_blocks.init_pipeline("YiYiXu/modular-doc-guider")
+assert t2i_pipeline.guider is None # Not created during init
+t2i_pipeline.load_default_components()
+t2i_pipeline.guider # Now loaded as PAG guider
+```
+
+#### Upload Custom Guider to Hub for Easy Loading & Sharing
+
+Now let's see how we can share the guider on the Hub and change it to a `from_pretrained` component.
+
+```py
+guider.push_to_hub("YiYiXu/modular-loader-t2i-guider", subfolder="pag_guider")
+```
+
+Voilà! Now you have a subfolder called `pag_guider` on that repository.
+
+You have a few options to make this guider available in your pipeline:
+
+1. **Directly modify the `modular_model_index.json`** to add a loading spec for the guider by pointing to a folder containing the desired guider config.
+
+2. **Use the `update_components` method** to change it to a `from_pretrained` component for your pipeline. This is easier if you just want to try it out with different repositories.
+
+Let's use the second approach and change our guider_spec to use `from_pretrained` as the default creation method and update the loading spec to use this subfolder we just created:
+
+```python
+guider_spec = t2i_pipeline.get_component_spec("guider")
+guider_spec.default_creation_method="from_pretrained"
+guider_spec.repo="YiYiXu/modular-loader-t2i-guider"
+guider_spec.subfolder="pag_guider"
+pag_guider = guider_spec.load()
+t2i_pipeline.update_components(guider=pag_guider)
+```
+
+You will get a warning about changing the creation method:
+
+```
+ModularPipeline.update_components: changing the default_creation_method of guider from from_config to from_pretrained.
+```
+
+Now not only the `guider` component and its component_spec are updated, but so is the pipeline config.
+
+If you want to change the default behavior for future pipelines, you can push the updated pipeline to the Hub. This way, when others use your repository, they'll get the PAG guider by default. However, this is optional - you don't have to do this if you just want to experiment locally.
+
+```py
+t2i_pipeline.push_to_hub("YiYiXu/modular-doc-guider")
+```
+
+
+
+
+Experiment with different techniques and parameters to find what works best for your specific use case! You can find all the guider class we support [here](TODO: API doc)
+
+Additionally, you can write your own guider implementations, for example, CFG Zero* combined with Skip Layer Guidance, and they should be compatible out-of-the-box with modular diffusers!
+
+
+
+## Running a `ModularPipeline`
+
+The API to run the `ModularPipeline` is very similar to how you would run a regular `DiffusionPipeline`:
+
+```py
+>>> image = pipeline(prompt="a cat", num_inference_steps=15, output="images")[0]
+```
+
+There are a few key differences though:
+1. You can also pass a `PipelineState` object directly to the pipeline instead of individual arguments
+2. If you do not specify the `output` argument, it returns the `PipelineState` object
+3. You can pass a list as `output`, e.g. `pipeline(... output=["images", "latents"])` will return a dictionary containing both the generated image and the final denoised latents
+
+Under the hood, `ModularPipeline`'s `__call__` method is a wrapper around the pipeline blocks' `__call__` method: it creates a `PipelineState` object and populates it with user inputs, then returns the output to the user based on the `output` argument. It also ensures that all pipeline-level config and components are exposed to all pipeline blocks by preparing and passing a `components` input.
+
+
+
+You can inspect the docstring of a `ModularPipeline` to check what arguments the pipeline accepts and how to specify the `output` you want. It will list all available outputs (basically everything in the intermediate pipeline state) so you can choose from the list.
+
+```py
+t2i_pipeline.doc
+```
+
+**Important**: It is important to always check the docstring because arguments can be different from standard pipelines that you're familar with. For example, in Modular Diffusers we standardized controlnet image input as `control_image`, but regular pipelines have inconsistencies over the names, e.g. controlnet text-to-image uses `image` while SDXL controlnet img2img uses `control_image`.
+
+**Note**: The `output` list might be longer than you expected - it includes everything in the intermediate state that you can choose to return. Most of the time, you'll just want `output="images"` or `output="latents"`.
+
+
+
+#### Text-to-Image, Image-to-Image, and Inpainting
+
+These are minimum inference examples for basic tasks: text-to-image, image-to-image, and inpainting. The process to create different pipelines is the same - only difference is the block classes presets. The inference is also more or less same to standard pipelines, but please always check `.doc` for correct input names and remember to pass `output="images"`.
+
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+# create pipeline from official blocks preset
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# run pipeline, need to pass a "output=images" argument
+image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
+image.save("modular_t2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
+
+# create pipeline from blocks preset
+blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+init_image = load_image(url)
+prompt = "a dog catching a frisbee in the jungle"
+image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
+image.save("modular_i2i_out.png")
+```
+
+
+
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
+from diffusers.utils import load_image
+
+# create pipeline from blocks preset
+blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
+
+modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
+mask_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-inpaint-mask.png"
+
+init_image = load_image(img_url)
+mask_image = load_image(mask_url)
+
+prompt = "A deep sea diver floating"
+image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
+image.save("moduar_inpaint_out.png")
+```
+
+
+
+
+#### ControlNet
+
+For ControlNet, we provide one auto block you can place at the `denoise` step. Let's create it and inspect it to see what it tells us.
+
+
+
+💡 **How to explore new tasks**: When you want to figure out how to do a specific task in Modular Diffusers, it is a good idea to start by checking what block classes presets we offer in `ALL_BLOCKS`. Then create the block instance and inspect it - it will show you the required components, description, and sub-blocks. This is crucial for understanding what each block does and what it needs.
+
+
+
+```py
+>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
+>>> ALL_BLOCKS["controlnet"]
+InsertableDict([
+ 0: ('denoise', )
+])
+>>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]()
+>>> controlnet_blocks
+StableDiffusionXLAutoControlnetStep(
+ Class: SequentialPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: {'mask', 'control_mode', 'control_image', 'controlnet_cond'}
+ Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('mask')`).
+ ====================================================================================================
+
+
+ Description: Controlnet auto step that prepare the controlnet input and denoise the latents. It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks. (it should be replace at 'denoise' step)
+
+
+ Components:
+ controlnet (`ControlNetUnionModel`)
+ control_image_processor (`VaeImageProcessor`)
+ scheduler (`EulerDiscreteScheduler`)
+ unet (`UNet2DConditionModel`)
+ guider (`ClassifierFreeGuidance`)
+
+ Sub-Blocks:
+ [0] controlnet_input (StableDiffusionXLAutoControlNetInputStep)
+ Description: Controlnet Input step that prepare the controlnet input.
+ This is an auto pipeline block that works for both controlnet and controlnet_union.
+ (it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
+ - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
+
+ [1] controlnet_denoise (StableDiffusionXLAutoControlNetDenoiseStep)
+ Description: Denoise step that iteratively denoise the latents with controlnet. This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks.This block should not be used without a controlnet_cond input - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided. - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided. - If neither mask nor controlnet_cond are provided, step will be skipped.
+
+)
+```
+
+
+
+💡 **Auto Blocks**: This is first time we meet a Auto Blocks! `AutoPipelineBlocks` automatically adapt to your inputs by combining multiple workflows with conditional logic. This is why one convenient block can work for all tasks and controlnet types. See the [Auto Blocks Guide](./auto_pipeline_blocks.md) for more details.
+
+
+
+The block shows us it has two steps (prepare inputs + denoise) and supports all tasks with both controlnet and controlnet union. Most importantly, it tells us to place it at the 'denoise' step. Let's do exactly that:
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS, StableDiffusionXLAutoControlnetStep
+from diffusers.utils import load_image
+
+# create pipeline from blocks preset
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+# these two lines applies controlnet
+controlnet_blocks = StableDiffusionXLAutoControlnetStep()
+blocks.sub_blocks["denoise"] = controlnet_blocks
+```
+
+Before we convert the blocks into a pipeline and load its components, let's inspect the blocks and its docs again to make sure it was assembled correctly. You should be able to see that `controlnet` and `control_image_processor` are now listed as `Components`, so we should initialize the pipeline with a repo that contains desired loading specs for these 2 components.
+
+```py
+# make sure to a modular_repo including controlnet
+modular_repo_id = "YiYiXu/modular-demo-auto"
+pipeline = blocks.init_pipeline(modular_repo_id)
+pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.to("cuda")
+
+# generate
+canny_image = load_image(
+ "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+)
+image = pipeline(
+ prompt="a bird", controlnet_conditioning_scale=0.5, control_image=canny_image, output="images"
+)[0]
+image.save("modular_control_out.png")
+```
+
+#### IP-Adapter
+
+**Challenge time!** Before we show you how to apply IP-adapter, try doing it yourself! Use the same process we just walked you through with ControlNet: check the official blocks preset, inspect the block instance and docstring `.doc`, and adapt a regular IP-adapter example to modular.
+
+Let's walk through the steps:
+
+1. Check blocks preset
+
+```py
+>>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
+>>> ALL_BLOCKS["ip_adapter"]
+InsertableDict([
+ 0: ('ip_adapter', )
+])
+```
+
+2. inspect the block & doc
+
+```
+>>> from diffusers.modular_pipelines.stable_diffusion_xl import StableDiffusionXLAutoIPAdapterStep
+>>> ip_adapter_blocks = StableDiffusionXLAutoIPAdapterStep()
+>>> ip_adapter_blocks
+StableDiffusionXLAutoIPAdapterStep(
+ Class: AutoPipelineBlocks
+
+ ====================================================================================================
+ This pipeline contains blocks that are selected at runtime based on inputs.
+ Trigger Inputs: {'ip_adapter_image'}
+ Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
+ ====================================================================================================
+
+
+ Description: Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.
+
+
+
+ Components:
+ image_encoder (`CLIPVisionModelWithProjection`)
+ feature_extractor (`CLIPImageProcessor`)
+ unet (`UNet2DConditionModel`)
+ guider (`ClassifierFreeGuidance`)
+
+ Sub-Blocks:
+ • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep)
+ Description: IP Adapter step that prepares ip adapter image embeddings.
+ Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().
+ See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details
+
+)
+```
+3. follow the instruction to build
+
+```py
+import torch
+from diffusers.modular_pipelines import SequentialPipelineBlocks
+from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
+
+# create pipeline from official blocks preset
+blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
+
+# insert ip_adapter_blocks before the input step as instructed
+blocks.sub_blocks.insert("ip_adapter", ip_adapter_blocks, 1)
+
+# inspec the blocks before you convert it into pipelines,
+# and make sure to use a repo that contains the loading spec for all components
+# for ip-adapter, you need image_encoder & feature_extractor
+modular_repo_id = "YiYiXu/modular-demo-auto"
+pipeline = blocks.init_pipeline(modular_repo_id)
+
+pipeline.load_default_components(torch_dtype=torch.float16)
+pipeline.load_ip_adapter(
+ "h94/IP-Adapter",
+ subfolder="sdxl_models",
+ weight_name="ip-adapter_sdxl.bin"
+)
+pipeline.set_ip_adapter_scale(0.8)
+pipeline.to("cuda")
+```
+
+4. adapt an example to modular
+
+We are using [this one](https://huggingface.co/docs/diffusers/using-diffusers/ip_adapter?ipadapter-variants=IP-Adapter+Plus#ip-adapter) from our IP-Adapter doc!
+
+
+```py
+from diffusers.utils import load_image
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")
+image = pipeline(
+ prompt="a polar bear sitting in a chair drinking a milkshake",
+ ip_adapter_image=image,
+ negative_prompt="deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+ output="images"
+)[0]
+image.save("modular_ipa_out.png")
+```
+
+
diff --git a/docs/source/en/modular_diffusers/overview.md b/docs/source/en/modular_diffusers/overview.md
new file mode 100644
index 0000000000..9702cea063
--- /dev/null
+++ b/docs/source/en/modular_diffusers/overview.md
@@ -0,0 +1,42 @@
+
+
+# Getting Started with Modular Diffusers
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+With Modular Diffusers, we introduce a unified pipeline system that simplifies how you work with diffusion models. Instead of creating separate pipelines for each task, Modular Diffusers lets you:
+
+**Write Only What's New**: You won't need to write an entire pipeline from scratch every time you have a new use case. You can create pipeline blocks just for your new workflow's unique aspects and reuse existing blocks for existing functionalities.
+
+**Assemble Like LEGO®**: You can mix and match between blocks in flexible ways. This allows you to write dedicated blocks unique to specific workflows, and then assemble different blocks into a pipeline that can be used more conveniently for multiple workflows.
+
+
+Here's how our guides are organized to help you navigate the Modular Diffusers documentation:
+
+### 🚀 Running Pipelines
+- **[Modular Pipeline Guide](./modular_pipeline.md)** - How to use predefined blocks to build a pipeline and run it
+- **[Components Manager Guide](./components_manager.md)** - How to manage and reuse components across multiple pipelines
+
+### 📚 Creating PipelineBlocks
+- **[Pipeline and Block States](./modular_diffusers_states.md)** - Understanding PipelineState and BlockState
+- **[Pipeline Block](./pipeline_block.md)** - How to write custom PipelineBlocks
+- **[SequentialPipelineBlocks](sequential_pipeline_blocks.md)** - Connecting blocks in sequence
+- **[LoopSequentialPipelineBlocks](./loop_sequential_pipeline_blocks.md)** - Creating iterative workflows
+- **[AutoPipelineBlocks](./auto_pipeline_blocks.md)** - Conditional block selection
+
+### 🎯 Practical Examples
+- **[End-to-End Example](./end_to_end_guide.md)** - Complete end-to-end examples including sharing your workflow in huggingface hub and deplying UI nodes
diff --git a/docs/source/en/modular_diffusers/pipeline_block.md b/docs/source/en/modular_diffusers/pipeline_block.md
new file mode 100644
index 0000000000..17a819732f
--- /dev/null
+++ b/docs/source/en/modular_diffusers/pipeline_block.md
@@ -0,0 +1,292 @@
+
+
+# PipelineBlock
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows.
+
+
+
+**Important**: `PipelineBlock`s are definitions/specifications, not runnable pipelines. They define what a block should do and what data it needs, but you need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](./modular_pipeline.md).
+
+
+
+In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with the pipeline state.
+
+## PipelineState
+
+Before we dive into creating `PipelineBlock`s, make sure you have a basic understanding of `PipelineState`. It acts as the global state container that all blocks operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes. See the [PipelineState and BlockState guide](./modular_diffusers_states.md) for more details.
+
+## Define a `PipelineBlock`
+
+To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce.
+
+The three main properties you need to define are:
+- `inputs`: Immutable values from the user that cannot be modified
+- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified
+- `intermediate_outputs`: New values your block creates for subsequent blocks and user access
+
+Let's explore each one and understand how they work with the pipeline state.
+
+**Inputs: Immutable User Values**
+
+Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:
+
+```py
+user_inputs = [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
+]
+```
+
+When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."
+
+This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.
+
+**Intermediate Inputs: Mutable Values from Previous Blocks, or Users**
+
+Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case:
+
+```py
+user_intermediate_inputs = [
+ InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
+]
+```
+
+When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different."
+
+**Intermediate Outputs: New Values for Subsequent Blocks and User Access**
+
+Intermediate outputs are new variables your block creates and adds to the mutable pipeline state. They serve two purposes:
+
+1. **For subsequent blocks**: They can be used as intermediate inputs by other blocks in the pipeline
+2. **For users**: They become available as final outputs that users can access when running the pipeline
+
+```py
+user_intermediate_outputs = [
+ OutputParam(name="image_latents", description="latents representing the image")
+]
+```
+
+Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match.
+
+Additionally, all intermediate outputs are accessible to users when they run the pipeline, typically you would only need the final images, but they are also able to access intermediate results like latents, embeddings, or other processing steps.
+
+**The `__call__` Method Structure**
+
+Your `PipelineBlock`'s `__call__` method should follow this structure:
+
+```py
+def __call__(self, components, state):
+ # Get a local view of the state variables this block needs
+ block_state = self.get_block_state(state)
+
+ # Your computation logic here
+ # block_state contains all your inputs and intermediate_inputs
+ # You can access them like: block_state.image, block_state.processed_image
+
+ # Update the pipeline state with your updated block_states
+ self.set_block_state(state, block_state)
+ return components, state
+```
+
+The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.
+
+**Components and Configs**
+
+You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:
+
+```py
+from diffusers import ComponentSpec, ConfigSpec
+
+# Define components your block needs
+expected_components = [
+ ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
+ ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
+]
+
+# Define pipeline-level configs
+expected_config = [
+ ConfigSpec("force_zeros_for_empty_prompt", True)
+]
+```
+
+**Components**: In the `ComponentSpec`, you must provide a `name` and ideally a `type_hint`. You can also specify a `default_creation_method` to indicate whether the component should be loaded from a pretrained model or created with default configurations. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Modular Pipeline Guide](./modular_pipeline.md).
+
+**Configs**: Pipeline-level settings that control behavior across all blocks.
+
+When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block as the first argument of the `__call__` method. You can access any component you need using dot notation:
+
+```py
+def __call__(self, components, state):
+ # Access components using dot notation
+ unet = components.unet
+ vae = components.vae
+ scheduler = components.scheduler
+```
+
+That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks
+
+Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
+
+**Helper Function**
+
+```py
+from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
+import torch
+
+def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
+ class TestBlock(PipelineBlock):
+ model_name = "test"
+
+ @property
+ def inputs(self):
+ return inputs
+
+ @property
+ def intermediate_inputs(self):
+ return intermediate_inputs
+
+ @property
+ def intermediate_outputs(self):
+ return intermediate_outputs
+
+ @property
+ def description(self):
+ return description if description is not None else ""
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ if block_fn is not None:
+ block_state = block_fn(block_state, state)
+ self.set_block_state(state, block_state)
+ return components, state
+
+ return TestBlock
+```
+
+## Example: Creating a Simple Pipeline Block
+
+Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them:
+
+```py
+inputs = [
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
+]
+
+intermediate_inputs = [InputParam(name="batch_size", type_hint=int)]
+
+intermediate_outputs = [
+ OutputParam(name="image_latents", description="latents representing the image")
+]
+
+def image_encoder_block_fn(block_state, pipeline_state):
+ print(f"pipeline_state (before update): {pipeline_state}")
+ print(f"block_state (before update): {block_state}")
+
+ # Simulate processing the image
+ block_state.image = torch.randn(1, 3, 512, 512)
+ block_state.batch_size = block_state.batch_size * 2
+ block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size
+ block_state.image_latents = torch.randn(1, 4, 64, 64)
+
+ print(f"block_state (after update): {block_state}")
+ return block_state
+
+# Create a block with our definitions
+image_encoder_block_cls = make_block(
+ inputs=inputs,
+ intermediate_inputs=intermediate_inputs,
+ intermediate_outputs=intermediate_outputs,
+ block_fn=image_encoder_block_fn,
+ description="Encode raw image into its latent presentation"
+)
+image_encoder_block = image_encoder_block_cls()
+pipe = image_encoder_block.init_pipeline()
+```
+
+Let's check the pipeline's docstring to see what inputs it expects:
+```py
+>>> print(pipe.doc)
+class TestBlock
+
+ Encode raw image into its latent presentation
+
+ Inputs:
+
+ image (`PIL.Image`, *optional*):
+ raw input image to process
+
+ batch_size (`int`, *optional*):
+
+ Outputs:
+
+ image_latents (`None`):
+ latents representing the image
+```
+
+Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state.
+
+Now let's run the pipeline:
+
+```py
+from diffusers.utils import load_image
+
+image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
+state = pipe(image=image, batch_size=2)
+print(f"pipeline_state (after update): {state}")
+```
+```out
+pipeline_state (before update): PipelineState(
+ inputs={
+ image:
+ },
+ intermediates={
+ batch_size: 2
+ },
+)
+block_state (before update): BlockState(
+ image:
+ batch_size: 2
+)
+
+block_state (after update): BlockState(
+ image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512]))
+ batch_size: 4
+ processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])]
+ image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
+)
+pipeline_state (after update): PipelineState(
+ inputs={
+ image:
+ },
+ intermediates={
+ batch_size: 4
+ image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
+ },
+)
+```
+
+**Key Observations:**
+
+1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`.
+
+2. **After the update**:
+ - **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only.
+ - **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict)
+ - **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output
+ - **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output
\ No newline at end of file
diff --git a/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
new file mode 100644
index 0000000000..a683f0d065
--- /dev/null
+++ b/docs/source/en/modular_diffusers/sequential_pipeline_blocks.md
@@ -0,0 +1,189 @@
+
+
+# SequentialPipelineBlocks
+
+
+
+🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
+
+
+
+`SequentialPipelineBlocks` is a subclass of `ModularPipelineBlocks`. Unlike `PipelineBlock`, it is a multi-block that composes other blocks together in sequence, creating modular workflows where data flows from one block to the next. It's one of the most common ways to build complex pipelines by combining simpler building blocks.
+
+
+
+Other types of multi-blocks include [AutoPipelineBlocks](auto_pipeline_blocks.md) (for conditional block selection) and [LoopSequentialPipelineBlocks](loop_sequential_pipeline_blocks.md) (for iterative workflows). For information on creating individual blocks, see the [PipelineBlock guide](pipeline_block.md).
+
+Additionally, like all `ModularPipelineBlocks`, `SequentialPipelineBlocks` are definitions/specifications, not runnable pipelines. You need to convert them into a `ModularPipeline` to actually execute them. For information on creating and running pipelines, see the [Modular Pipeline guide](modular_pipeline.md).
+
+
+
+In this tutorial, we will focus on how to create `SequentialPipelineBlocks` and how blocks connect and work together.
+
+The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed in the [PipelineBlock guide](pipeline_block.md). When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks.
+
+Let's explore this through an example. We will use the same helper function from the PipelineBlock guide to create blocks.
+
+```py
+from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
+import torch
+
+def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
+ class TestBlock(PipelineBlock):
+ model_name = "test"
+
+ @property
+ def inputs(self):
+ return inputs
+
+ @property
+ def intermediate_inputs(self):
+ return intermediate_inputs
+
+ @property
+ def intermediate_outputs(self):
+ return intermediate_outputs
+
+ @property
+ def description(self):
+ return description if description is not None else ""
+
+ def __call__(self, components, state):
+ block_state = self.get_block_state(state)
+ if block_fn is not None:
+ block_state = block_fn(block_state, state)
+ self.set_block_state(state, block_state)
+ return components, state
+
+ return TestBlock
+```
+
+Let's create a block that produces `batch_size`, which we'll call "input_block":
+
+```py
+def input_block_fn(block_state, pipeline_state):
+
+ batch_size = len(block_state.prompt)
+ block_state.batch_size = batch_size * block_state.num_images_per_prompt
+
+ return block_state
+
+input_block_cls = make_block(
+ inputs=[
+ InputParam(name="prompt", type_hint=list, description="list of text prompts"),
+ InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt")
+ ],
+ intermediate_outputs=[
+ OutputParam(name="batch_size", description="calculated batch size")
+ ],
+ block_fn=input_block_fn,
+ description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
+)
+input_block = input_block_cls()
+```
+
+Now let's create a second block that uses the `batch_size` from the first block:
+
+```py
+def image_encoder_block_fn(block_state, pipeline_state):
+ # Simulate processing the image
+ block_state.image = torch.randn(1, 3, 512, 512)
+ block_state.batch_size = block_state.batch_size * 2
+ block_state.image_latents = torch.randn(1, 4, 64, 64)
+ return block_state
+
+image_encoder_block_cls = make_block(
+ inputs=[
+ InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
+ ],
+ intermediate_inputs=[
+ InputParam(name="batch_size", type_hint=int)
+ ],
+ intermediate_outputs=[
+ OutputParam(name="image_latents", description="latents representing the image")
+ ],
+ block_fn=image_encoder_block_fn,
+ description="Encode raw image into its latent presentation"
+)
+image_encoder_block = image_encoder_block_cls()
+```
+
+Now let's connect these blocks to create a `SequentialPipelineBlocks`:
+
+```py
+from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
+
+# Define a dict mapping block names to block instances
+blocks_dict = InsertableDict()
+blocks_dict["input"] = input_block
+blocks_dict["image_encoder"] = image_encoder_block
+
+# Create the SequentialPipelineBlocks
+blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
+```
+
+Now you have a `SequentialPipelineBlocks` with 2 blocks:
+
+```py
+>>> blocks
+SequentialPipelineBlocks(
+ Class: ModularPipelineBlocks
+
+ Description:
+
+
+ Sub-Blocks:
+ [0] input (TestBlock)
+ Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.
+
+ [1] image_encoder (TestBlock)
+ Description: Encode raw image into its latent presentation
+
+)
+```
+
+When you inspect `blocks.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it.
+
+```py
+>>> print(blocks.doc)
+class SequentialPipelineBlocks
+
+ Inputs:
+
+ prompt (`None`, *optional*):
+
+ num_images_per_prompt (`None`, *optional*):
+
+ image (`PIL.Image`, *optional*):
+ raw input image to process
+
+ Outputs:
+
+ batch_size (`None`):
+
+ image_latents (`None`):
+ latents representing the image
+```
+
+At runtime, you have data flow like this:
+
+
+
+**How SequentialPipelineBlocks Works:**
+
+1. Blocks are executed in the order they're registered in the `blocks_dict`
+2. Outputs from one block become available as intermediate inputs to all subsequent blocks
+3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks
+4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces
+
+What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs.
\ No newline at end of file
diff --git a/docs/source/en/optimization/fp16.md b/docs/source/en/optimization/fp16.md
index 2e12bfadcf..e32cbec917 100644
--- a/docs/source/en/optimization/fp16.md
+++ b/docs/source/en/optimization/fp16.md
@@ -150,11 +150,60 @@ pipeline(prompt, num_inference_steps=30).images[0]
Compilation is slow the first time, but once compiled, it is significantly faster. Try to only use the compiled pipeline on the same type of inference operations. Calling the compiled pipeline on a different image size retriggers compilation which is slow and inefficient.
+### Dynamic shape compilation
+
+> [!TIP]
+> Make sure to always use the nightly version of PyTorch for better support.
+
+`torch.compile` keeps track of input shapes and conditions, and if these are different, it recompiles the model. For example, if a model is compiled on a 1024x1024 resolution image and used on an image with a different resolution, it triggers recompilation.
+
+To avoid recompilation, add `dynamic=True` to try and generate a more dynamic kernel to avoid recompilation when conditions change.
+
+```diff
++ torch.fx.experimental._config.use_duck_shape = False
++ pipeline.unet = torch.compile(
+ pipeline.unet, fullgraph=True, dynamic=True
+)
+```
+
+Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic variable to represent input sizes that are the same. For more details, check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+
+Not all models may benefit from dynamic compilation out of the box and may require changes. Refer to this [PR](https://github.com/huggingface/diffusers/pull/11297/) that improved the [`AuraFlowPipeline`] implementation to benefit from dynamic compilation.
+
+Feel free to open an issue if dynamic compilation doesn't work as expected for a Diffusers model.
+
### Regional compilation
-[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) reduces the cold start compilation time by only compiling a specific repeated region (or block) of the model instead of the entire model. The compiler reuses the cached and compiled code for the other blocks.
+[Regional compilation](https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) trims cold-start latency by only compiling the *small and frequently-repeated block(s)* of a model - typically a transformer layer - and enables reusing compiled artifacts for every subsequent occurrence.
+For many diffusion architectures, this delivers the same runtime speedups as full-graph compilation and reduces compile time by 8–10x.
-[Accelerate](https://huggingface.co/docs/accelerate/index) provides the [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method for automatically compiling the repeated blocks of a `nn.Module` sequentially. The rest of the model is compiled separately.
+Use the [`~ModelMixin.compile_repeated_blocks`] method, a helper that wraps `torch.compile`, on any component such as the transformer model as shown below.
+
+```py
+# pip install -U diffusers
+import torch
+from diffusers import StableDiffusionXLPipeline
+
+pipeline = StableDiffusionXLPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16,
+).to("cuda")
+
+# compile only the repeated transformer layers inside the UNet
+pipeline.unet.compile_repeated_blocks(fullgraph=True)
+```
+
+To enable regional compilation for a new model, add a `_repeated_blocks` attribute to a model class containing the class names (as strings) of the blocks you want to compile.
+
+```py
+class MyUNet(ModelMixin):
+ _repeated_blocks = ("Transformer2DModel",) # ← compiled by default
+```
+
+> [!TIP]
+> For more regional compilation examples, see the reference [PR](https://github.com/huggingface/diffusers/pull/11705).
+
+There is also a [compile_regions](https://github.com/huggingface/accelerate/blob/273799c85d849a1954a4f2e65767216eb37fa089/src/accelerate/utils/other.py#L78) method in [Accelerate](https://huggingface.co/docs/accelerate/index) that automatically selects candidate blocks in a model to compile. The remaining graph is compiled separately. This is useful for quick experiments because there aren't as many options for you to set which blocks to compile or adjust compilation flags.
```py
# pip install -U accelerate
@@ -168,6 +217,8 @@ pipeline = StableDiffusionXLPipeline.from_pretrained(
pipeline.unet = compile_regions(pipeline.unet, mode="reduce-overhead", fullgraph=True)
```
+[`~ModelMixin.compile_repeated_blocks`] is intentionally explicit. List the blocks to repeat in `_repeated_blocks` and the helper only compiles those blocks. It offers predictable behavior and easy reasoning about cache reuse in one line of code.
+
### Graph breaks
It is important to specify `fullgraph=True` in torch.compile to ensure there are no graph breaks in the underlying model. This allows you to take advantage of torch.compile without any performance degradation. For the UNet and VAE, this changes how you access the return variables.
@@ -188,6 +239,12 @@ The `step()` function is [called](https://github.com/huggingface/diffusers/blob/
In general, the `sigmas` should [stay on the CPU](https://github.com/huggingface/diffusers/blob/35a969d297cba69110d175ee79c59312b9f49e1e/src/diffusers/schedulers/scheduling_euler_discrete.py#L240) to avoid the communication sync and latency.
+
+
+Refer to the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post for maximizing performance with `torch.compile` for diffusion models.
+
+
+
### Benchmarks
Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/benchmarks) dataset to see inference latency and memory usage data for compiled pipelines.
@@ -241,4 +298,12 @@ An input is projected into three subspaces, represented by the projection matric
```py
pipeline.fuse_qkv_projections()
-```
\ No newline at end of file
+```
+
+## Resources
+
+- Read the [Presenting Flux Fast: Making Flux go brrr on H100s](https://pytorch.org/blog/presenting-flux-fast-making-flux-go-brrr-on-h100s/) blog post to learn more about how you can combine all of these optimizations with [TorchInductor](https://docs.pytorch.org/docs/stable/torch.compiler.html) and [AOTInductor](https://docs.pytorch.org/docs/stable/torch.compiler_aot_inductor.html) for a ~2.5x speedup using recipes from [flux-fast](https://github.com/huggingface/flux-fast).
+
+ These recipes support AMD hardware and [Flux.1 Kontext Dev](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev).
+- Read the [torch.compile and Diffusers: A Hands-On Guide to Peak Performance](https://pytorch.org/blog/torch-compile-and-diffusers-a-hands-on-guide-to-peak-performance/) blog post
+to maximize performance when using `torch.compile`.
\ No newline at end of file
diff --git a/docs/source/en/optimization/speed-memory-optims.md b/docs/source/en/optimization/speed-memory-optims.md
index 4a76d272cf..f43e60bc74 100644
--- a/docs/source/en/optimization/speed-memory-optims.md
+++ b/docs/source/en/optimization/speed-memory-optims.md
@@ -14,6 +14,9 @@ specific language governing permissions and limitations under the License.
Optimizing models often involves trade-offs between [inference speed](./fp16) and [memory-usage](./memory). For instance, while [caching](./cache) can boost inference speed, it also increases memory consumption since it needs to store the outputs of intermediate attention layers. A more balanced optimization strategy combines quantizing a model, [torch.compile](./fp16#torchcompile) and various [offloading methods](./memory#offloading).
+> [!TIP]
+> Check the [torch.compile](./fp16#torchcompile) guide to learn more about compilation and how they can be applied here. For example, regional compilation can significantly reduce compilation time without giving up any speedups.
+
For image generation, combining quantization and [model offloading](./memory#model-offloading) can often give the best trade-off between quality, speed, and memory. Group offloading is not as effective for image generation because it is usually not possible to *fully* overlap data transfer if the compute kernel finishes faster. This results in some communication overhead between the CPU and GPU.
For video generation, combining quantization and [group-offloading](./memory#group-offloading) tends to be better because video models are more compute-bound.
@@ -25,7 +28,7 @@ The table below provides a comparison of optimization strategy combinations and
| quantization | 32.602 | 14.9453 |
| quantization, torch.compile | 25.847 | 14.9448 |
| quantization, torch.compile, model CPU offloading | 32.312 | 12.2369 |
-These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the if you're interested in evaluating your own model.
+These results are benchmarked on Flux with a RTX 4090. The transformer and text_encoder components are quantized. Refer to the [benchmarking script](https://gist.github.com/sayakpaul/0db9d8eeeb3d2a0e5ed7cf0d9ca19b7d) if you're interested in evaluating your own model.
This guide will show you how to compile and offload a quantized model with [bitsandbytes](../quantization/bitsandbytes#torchcompile). Make sure you are using [PyTorch nightly](https://pytorch.org/get-started/locally/) and the latest version of bitsandbytes.
diff --git a/docs/source/en/quantization/overview.md b/docs/source/en/quantization/overview.md
index da11f57ec0..12c39f52e4 100644
--- a/docs/source/en/quantization/overview.md
+++ b/docs/source/en/quantization/overview.md
@@ -11,7 +11,7 @@ specific language governing permissions and limitations under the License.
-->
-# Quantization
+# Getting started
Quantization focuses on representing data with fewer bits while also trying to preserve the precision of the original data. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
@@ -19,19 +19,25 @@ Diffusers supports multiple quantization backends to make large diffusion models
## Pipeline-level quantization
-There are two ways you can use [`~quantizers.PipelineQuantizationConfig`] depending on the level of control you want over the quantization specifications of each model in the pipeline.
+There are two ways to use [`~quantizers.PipelineQuantizationConfig`] depending on how much customization you want to apply to the quantization configuration.
-- for more basic and simple use cases, you only need to define the `quant_backend`, `quant_kwargs`, and `components_to_quantize`
-- for more granular quantization control, provide a `quant_mapping` that provides the quantization specifications for the individual model components
+- for basic use cases, define the `quant_backend`, `quant_kwargs`, and `components_to_quantize` arguments
+- for granular quantization control, define a `quant_mapping` that provides the quantization configuration for individual model components
-### Simple quantization
+### Basic quantization
Initialize [`~quantizers.PipelineQuantizationConfig`] with the following parameters.
- `quant_backend` specifies which quantization backend to use. Currently supported backends include: `bitsandbytes_4bit`, `bitsandbytes_8bit`, `gguf`, `quanto`, and `torchao`.
-- `quant_kwargs` contains the specific quantization arguments to use.
+- `quant_kwargs` specifies the quantization arguments to use.
+
+> [!TIP]
+> These `quant_kwargs` arguments are different for each backend. Refer to the [Quantization API](../api/quantization) docs to view the arguments for each backend.
+
- `components_to_quantize` specifies which components of the pipeline to quantize. Typically, you should quantize the most compute intensive components like the transformer. The text encoder is another component to consider quantizing if a pipeline has more than one such as [`FluxPipeline`]. The example below quantizes the T5 text encoder in [`FluxPipeline`] while keeping the CLIP model intact.
+The example below loads the bitsandbytes backend with the following arguments from [`~quantizers.quantization_config.BitsAndBytesConfig`], `load_in_4bit`, `bnb_4bit_quant_type`, and `bnb_4bit_compute_dtype`.
+
```py
import torch
from diffusers import DiffusionPipeline
@@ -56,13 +62,13 @@ pipe = DiffusionPipeline.from_pretrained(
image = pipe("photo of a cute dog").images[0]
```
-### quant_mapping
+### Advanced quantization
-The `quant_mapping` argument provides more flexible options for how to quantize each individual component in a pipeline, like combining different quantization backends.
+The `quant_mapping` argument provides more options for how to quantize each individual component in a pipeline, like combining different quantization backends.
Initialize [`~quantizers.PipelineQuantizationConfig`] and pass a `quant_mapping` to it. The `quant_mapping` allows you to specify the quantization options for each component in the pipeline such as the transformer and text encoder.
-The example below uses two quantization backends, [`~quantizers.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.
+The example below uses two quantization backends, [`~quantizers.quantization_config.QuantoConfig`] and [`transformers.BitsAndBytesConfig`], for the transformer and text encoder.
```py
import torch
@@ -85,7 +91,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
There is a separate bitsandbytes backend in [Transformers](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig). You need to import and use [`transformers.BitsAndBytesConfig`] for components that come from Transformers. For example, `text_encoder_2` in [`FluxPipeline`] is a [`~transformers.T5EncoderModel`] from Transformers so you need to use [`transformers.BitsAndBytesConfig`] instead of [`diffusers.BitsAndBytesConfig`].
> [!TIP]
-> Use the [simple quantization](#simple-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.
+> Use the [basic quantization](#basic-quantization) method above if you don't want to manage these distinct imports or aren't sure where each pipeline component comes from.
```py
import torch
@@ -129,4 +135,4 @@ Check out the resources below to learn more about quantization.
- The Transformers quantization [Overview](https://huggingface.co/docs/transformers/quantization/overview#when-to-use-what) provides an overview of the pros and cons of different quantization backends.
-- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
\ No newline at end of file
+- Read the [Exploring Quantization Backends in Diffusers](https://huggingface.co/blog/diffusers-quantization) blog post for a brief introduction to each quantization backend, how to choose a backend, and combining quantization with other memory optimizations.
diff --git a/docs/source/en/training/cogvideox.md b/docs/source/en/training/cogvideox.md
index f277d56136..d0700c4da7 100644
--- a/docs/source/en/training/cogvideox.md
+++ b/docs/source/en/training/cogvideox.md
@@ -145,10 +145,10 @@ When running `accelerate config`, if you use torch.compile, there can be dramati
If you would like to push your model to the Hub after training is completed with a neat model card, make sure you're logged in:
```bash
-huggingface-cli login
+hf auth login
# Alternatively, you could upload your model manually using:
-# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
+# hf upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
```
Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
diff --git a/docs/source/en/training/create_dataset.md b/docs/source/en/training/create_dataset.md
index f3221beb40..8e0d6f9200 100644
--- a/docs/source/en/training/create_dataset.md
+++ b/docs/source/en/training/create_dataset.md
@@ -67,7 +67,7 @@ dataset = load_dataset(
Then use the [`~datasets.Dataset.push_to_hub`] method to upload the dataset to the Hub:
```python
-# assuming you have ran the huggingface-cli login command in a terminal
+# assuming you have ran the hf auth login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
diff --git a/docs/source/en/tutorials/basic_training.md b/docs/source/en/tutorials/basic_training.md
index 1ed81dd672..9a35b3438f 100644
--- a/docs/source/en/tutorials/basic_training.md
+++ b/docs/source/en/tutorials/basic_training.md
@@ -42,7 +42,7 @@ We encourage you to share your model with the community, and in order to do that
Or login in from the terminal:
```bash
-huggingface-cli login
+hf auth login
```
Since the model checkpoints are quite large, install [Git-LFS](https://git-lfs.com/) to version these large files:
diff --git a/docs/source/en/tutorials/tutorial_overview.md b/docs/source/en/tutorials/tutorial_overview.md
deleted file mode 100644
index e8700d82c0..0000000000
--- a/docs/source/en/tutorials/tutorial_overview.md
+++ /dev/null
@@ -1,23 +0,0 @@
-
-
-# Overview
-
-Welcome to 🧨 Diffusers! If you're new to diffusion models and generative AI, and want to learn more, then you've come to the right place. These beginner-friendly tutorials are designed to provide a gentle introduction to diffusion models and help you understand the library fundamentals - the core components and how 🧨 Diffusers is meant to be used.
-
-You'll learn how to use a pipeline for inference to rapidly generate things, and then deconstruct that pipeline to really understand how to use the library as a modular toolbox for building your own diffusion systems. In the next lesson, you'll learn how to train your own diffusion model to generate what you want.
-
-After completing the tutorials, you'll have gained the necessary skills to start exploring the library on your own and see how to use it for your own projects and applications.
-
-Feel free to join our community on [Discord](https://discord.com/invite/JfAtkvEtRb) or the [forums](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) to connect and collaborate with other users and developers!
-
-Let's start diffusing! 🧨
diff --git a/docs/source/en/tutorials/using_peft_for_inference.md b/docs/source/en/tutorials/using_peft_for_inference.md
index b18977720c..5cd47f8674 100644
--- a/docs/source/en/tutorials/using_peft_for_inference.md
+++ b/docs/source/en/tutorials/using_peft_for_inference.md
@@ -315,8 +315,23 @@ pipeline.load_lora_weights(
> [!TIP]
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
+If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
+
There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
+
+Technical details of hotswapping
+
+The [`~loaders.lora_base.LoraBaseMixin.enable_lora_hotswap`] method converts the LoRA scaling factor from floats to torch.tensors and pads the shape of the weights to the largest required shape to avoid reassigning the whole attribute when the data in the weights are replaced.
+
+This is why the `max_rank` argument is important. The results are unchanged even when the values are padded with zeros. Computation may be slower though depending on the padding size.
+
+Since no new LoRA attributes are added, each subsequent LoRA is only allowed to target the same layers, or subset of layers, the first LoRA targets. Choosing the LoRA loading order is important because if the LoRAs target disjoint layers, you may end up creating a dummy LoRA that targets the union of all target layers.
+
+For more implementation details, take a look at the [`hotswap.py`](https://github.com/huggingface/peft/blob/92d65cafa51c829484ad3d95cf71d09de57ff066/src/peft/utils/hotswap.py) file.
+
+
+
## Merge
The weights from each LoRA can be merged together to produce a blend of multiple existing styles. There are several methods for merging LoRAs, each of which differ in *how* the weights are merged (may affect generation quality).
@@ -671,4 +686,6 @@ Browse the [LoRA Studio](https://lorastudio.co/models) for different LoRAs to us
height="450"
>
-You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
\ No newline at end of file
+You can find additional LoRAs in the [FLUX LoRA the Explorer](https://huggingface.co/spaces/multimodalart/flux-lora-the-explorer) and [LoRA the Explorer](https://huggingface.co/spaces/multimodalart/LoraTheExplorer) Spaces.
+
+Check out the [Fast LoRA inference for Flux with Diffusers and PEFT](https://huggingface.co/blog/lora-fast) blog post to learn how to optimize LoRA inference with methods like FlashAttention-3 and fp8 quantization.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/batched_inference.md b/docs/source/en/using-diffusers/batched_inference.md
new file mode 100644
index 0000000000..b5e55c27ca
--- /dev/null
+++ b/docs/source/en/using-diffusers/batched_inference.md
@@ -0,0 +1,264 @@
+
+
+# Batch inference
+
+Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing multiple prompts at once maximizes GPU usage versus processing a single prompt and underutilizing the GPU.
+
+The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
+
+
+
+
+For text-to-image, pass a list of prompts to the pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+images = pipeline(
+ prompt=prompts,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+For image-to-image, pass a list of input images and prompts to the pipeline.
+
+```py
+import torch
+from diffusers.utils import load_image
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ guidance_scale=8.0,
+ strength=0.5
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
+
+```py
+import torch
+import matplotlib.pyplot as plt
+from diffusers.utils import load_image
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+
+images = pipeline(
+ prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
+ image=input_image,
+ num_images_per_prompt=4
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+Combine both approaches to generate different variations of different prompts.
+
+```py
+input_images = [
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
+ load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
+]
+
+prompts = [
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ image=input_images,
+ num_images_per_prompt=2,
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+
+
+
+## Deterministic generation
+
+Enable reproducible batch generation by passing a list of [Generator’s](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed to reuse it.
+
+Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
+
+Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
+
+```py
+generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
+```
+
+Pass the `generator` to the pipeline.
+
+```py
+import torch
+from diffusers import DiffusionPipeline
+
+pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0",
+ torch_dtype=torch.float16
+).to("cuda")
+
+generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
+prompts = [
+ "cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
+ "cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
+ "pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
+]
+
+images = pipeline(
+ prompt=prompts,
+ generator=generator
+).images
+
+fig, axes = plt.subplots(2, 2, figsize=(12, 12))
+axes = axes.flatten()
+
+for i, image in enumerate(images):
+ axes[i].imshow(image)
+ axes[i].set_title(f"Image {i+1}")
+ axes[i].axis('off')
+
+plt.tight_layout()
+plt.show()
+```
+
+You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
\ No newline at end of file
diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md
index df3df92f06..11afbf29d3 100644
--- a/docs/source/en/using-diffusers/other-formats.md
+++ b/docs/source/en/using-diffusers/other-formats.md
@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
-#### LoRA files
+#### LoRAs
-[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
+[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoids missing or incorrect LoRA configurations.
-LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
-
-```py
-from diffusers import StableDiffusionXLPipeline
-import torch
-
-# base model
-pipeline = StableDiffusionXLPipeline.from_pretrained(
- "Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
-).to("cuda")
-
-# download LoRA weights
-!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
-
-# load LoRA weights
-pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
-prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
-negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
-
-image = pipeline(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.manual_seed(0),
-).images[0]
-image
-```
+The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
-

+
+For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
+
+```py
+import torch
+from diffusers import FluxPipeline
+
+pipeline = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+).to("cuda")
+pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
+pipeline.save_lora_weights(
+ transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
+ text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
+)
+```
+
### ckpt
> [!WARNING]
diff --git a/docs/source/en/using-diffusers/overview_techniques.md b/docs/source/en/using-diffusers/overview_techniques.md
deleted file mode 100644
index a0b37cc52f..0000000000
--- a/docs/source/en/using-diffusers/overview_techniques.md
+++ /dev/null
@@ -1,18 +0,0 @@
-
-
-# Overview
-
-The inference pipeline supports and enables a wide range of techniques that are divided into two categories:
-
-* Pipeline functionality: these techniques modify the pipeline or extend it for other applications. For example, pipeline callbacks add new features to a pipeline and a pipeline can also be extended for distributed inference.
-* Improve inference quality: these techniques increase the visual quality of the generated images. For example, you can enhance your prompts with GPT2 to create better images with lower effort.
diff --git a/docs/source/en/using-diffusers/reusing_seeds.md b/docs/source/en/using-diffusers/reusing_seeds.md
index 60b8fee754..ac9350f24c 100644
--- a/docs/source/en/using-diffusers/reusing_seeds.md
+++ b/docs/source/en/using-diffusers/reusing_seeds.md
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
print("L_inf dist =", abs(result1 - result2).max())
"L_inf dist = tensor(0., device='cuda:0')"
```
-
-## Deterministic batch generation
-
-A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
-
-Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
-
-```py
-import torch
-from diffusers import DiffusionPipeline
-from diffusers.utils import make_image_grid
-
-pipeline = DiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
-)
-pipeline = pipeline.to("cuda")
-```
-
-Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
-
-> [!WARNING]
-> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
->
-> ```py
-> [torch.Generator().manual_seed(seed)] * 4
-> ```
-
-```python
-generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
-prompt = "Labrador in the style of Vermeer"
-images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
-
-Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
-
-```python
-prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
-generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
-images = pipeline(prompt, generator=generator).images
-make_image_grid(images, rows=2, cols=2)
-```
-
-
-

-
diff --git a/docs/source/en/using-diffusers/schedulers.md b/docs/source/en/using-diffusers/schedulers.md
index a3efbf2e80..aabb9dd31c 100644
--- a/docs/source/en/using-diffusers/schedulers.md
+++ b/docs/source/en/using-diffusers/schedulers.md
@@ -242,3 +242,15 @@ unet = UNet2DConditionModel.from_pretrained(
)
unet.save_pretrained("./local-unet", variant="non_ema")
```
+
+Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
+
+```py
+from diffusers import AutoModel
+
+unet = AutoModel.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
+)
+```
+
+You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
diff --git a/docs/source/ko/optimization/mps.md b/docs/source/ko/optimization/mps.md
index 218c4790a5..4daeaf5dba 100644
--- a/docs/source/ko/optimization/mps.md
+++ b/docs/source/ko/optimization/mps.md
@@ -37,7 +37,7 @@ Diffusers는 Stable Diffusion 추론을 위해 PyTorch `mps`를 사용해 Apple
```python
-# `huggingface-cli login`에 로그인되어 있음을 확인
+# `hf auth login`에 로그인되어 있음을 확인
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
diff --git a/docs/source/ko/training/create_dataset.md b/docs/source/ko/training/create_dataset.md
index 401a73ebf2..a869cd09f0 100644
--- a/docs/source/ko/training/create_dataset.md
+++ b/docs/source/ko/training/create_dataset.md
@@ -75,7 +75,7 @@ dataset = load_dataset(
[push_to_hub(https://huggingface.co/docs/datasets/v2.13.1/en/package_reference/main_classes#datasets.Dataset.push_to_hub) 을 사용해서 Hub에 데이터셋을 업로드 합니다:
```python
-# 터미널에서 huggingface-cli login 커맨드를 이미 실행했다고 가정합니다
+# 터미널에서 hf auth login 커맨드를 이미 실행했다고 가정합니다
dataset.push_to_hub("name_of_your_dataset")
# 개인 repo로 push 하고 싶다면, `private=True` 을 추가하세요:
diff --git a/docs/source/ko/training/lora.md b/docs/source/ko/training/lora.md
index 41ea8dbd46..5bcef27143 100644
--- a/docs/source/ko/training/lora.md
+++ b/docs/source/ko/training/lora.md
@@ -39,7 +39,7 @@ specific language governing permissions and limitations under the License.
모델을 저장하거나 커뮤니티와 공유하려면 Hugging Face 계정에 로그인하세요(아직 계정이 없는 경우 [생성](https://huggingface.co/join)하세요):
```bash
-huggingface-cli login
+hf auth login
```
## Text-to-image
diff --git a/docs/source/ko/tutorials/basic_training.md b/docs/source/ko/tutorials/basic_training.md
index bb49771052..2c4c89edd1 100644
--- a/docs/source/ko/tutorials/basic_training.md
+++ b/docs/source/ko/tutorials/basic_training.md
@@ -42,7 +42,7 @@ Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한
또는 터미널로 로그인할 수 있습니다:
```bash
-huggingface-cli login
+hf auth login
```
모델 체크포인트가 상당히 크기 때문에 [Git-LFS](https://git-lfs.com/)에서 대용량 파일의 버전 관리를 할 수 있습니다.
diff --git a/docs/source/ko/using-diffusers/other-formats.md b/docs/source/ko/using-diffusers/other-formats.md
index 95b2485f61..3034551f48 100644
--- a/docs/source/ko/using-diffusers/other-formats.md
+++ b/docs/source/ko/using-diffusers/other-formats.md
@@ -42,7 +42,7 @@ Stable Diffusion 모델들은 학습 및 저장된 프레임워크와 다운로
시작하기 전에 스크립트를 실행할 🤗 Diffusers의 로컬 클론(clone)이 있는지 확인하고 Hugging Face 계정에 로그인하여 pull request를 열고 변환된 모델을 허브에 푸시할 수 있도록 하세요.
```bash
-huggingface-cli login
+hf auth login
```
스크립트를 사용하려면:
diff --git a/examples/advanced_diffusion_training/README.md b/examples/advanced_diffusion_training/README.md
index eedb1c96e4..c9c3c1c508 100644
--- a/examples/advanced_diffusion_training/README.md
+++ b/examples/advanced_diffusion_training/README.md
@@ -69,7 +69,7 @@ Note also that we use PEFT library as backend for LoRA training, make sure to ha
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
-huggingface-cli login
+hf auth login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
diff --git a/examples/advanced_diffusion_training/README_flux.md b/examples/advanced_diffusion_training/README_flux.md
index 62f9078949..65e59ba6e7 100644
--- a/examples/advanced_diffusion_training/README_flux.md
+++ b/examples/advanced_diffusion_training/README_flux.md
@@ -67,7 +67,7 @@ Note also that we use PEFT library as backend for LoRA training, make sure to ha
Lastly, we recommend logging into your HF account so that your trained LoRA is automatically uploaded to the hub:
```bash
-huggingface-cli login
+hf auth login
```
This command will prompt you for a token. Copy-paste yours from your [settings/tokens](https://huggingface.co/settings/tokens),and press Enter.
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
index 173d3bfd5b..a30624e35a 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import copy
import itertools
@@ -75,7 +89,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -971,6 +985,7 @@ class DreamBoothDataset(Dataset):
def __init__(
self,
+ args,
instance_data_root,
instance_prompt,
class_prompt,
@@ -980,10 +995,8 @@ class DreamBoothDataset(Dataset):
class_num=None,
size=1024,
repeats=1,
- center_crop=False,
):
self.size = size
- self.center_crop = center_crop
self.instance_prompt = instance_prompt
self.custom_instance_prompts = None
@@ -1058,7 +1071,7 @@ class DreamBoothDataset(Dataset):
if interpolation is None:
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
train_resize = transforms.Resize(size, interpolation=interpolation)
- train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+ train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
train_flip = transforms.RandomHorizontalFlip(p=1.0)
train_transforms = transforms.Compose(
[
@@ -1075,11 +1088,11 @@ class DreamBoothDataset(Dataset):
# flip
image = train_flip(image)
if args.center_crop:
- y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
- x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+ y1 = max(0, int(round((image.height - self.size) / 2.0)))
+ x1 = max(0, int(round((image.width - self.size) / 2.0)))
image = train_crop(image)
else:
- y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+ y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
image = crop(image, y1, x1, h, w)
image = train_transforms(image)
self.pixel_values.append(image)
@@ -1102,7 +1115,7 @@ class DreamBoothDataset(Dataset):
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=interpolation),
- transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+ transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -1322,7 +1335,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
@@ -1827,6 +1840,7 @@ def main(args):
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
+ args=args,
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,7 +1850,6 @@ def main(args):
class_num=args.num_class_images,
size=args.resolution,
repeats=args.repeats,
- center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
index 52aee07e81..17c5150eb1 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import gc
import hashlib
@@ -73,7 +87,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -1050,7 +1064,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
index 911102c049..65e280801c 100644
--- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
+++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import gc
import itertools
@@ -80,7 +94,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -1292,7 +1306,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
diff --git a/examples/cogvideo/README.md b/examples/cogvideo/README.md
index dc74690983..ab0facc0a1 100644
--- a/examples/cogvideo/README.md
+++ b/examples/cogvideo/README.md
@@ -125,10 +125,10 @@ When running `accelerate config`, if we specify torch compile mode to True there
If you would like to push your model to the HF Hub after training is completed with a neat model card, make sure you're logged in:
```
-huggingface-cli login
+hf auth login
# Alternatively, you could upload your model manually using:
-# huggingface-cli upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
+# hf upload my-cool-account-name/my-cool-lora-name /path/to/awesome/lora
```
Make sure your data is prepared as described in [Data Preparation](#data-preparation). When ready, you can begin training!
diff --git a/examples/cogvideo/train_cogvideox_image_to_video_lora.py b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
index 315c61c60b..1ebc58b494 100644
--- a/examples/cogvideo/train_cogvideox_image_to_video_lora.py
+++ b/examples/cogvideo/train_cogvideox_image_to_video_lora.py
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -962,7 +962,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index a8e73e938c..f6903fde0a 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -52,7 +52,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -984,7 +984,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/cogview4-control/README.md b/examples/cogview4-control/README.md
index 746a99a1a4..c73c5ed3ca 100644
--- a/examples/cogview4-control/README.md
+++ b/examples/cogview4-control/README.md
@@ -10,7 +10,7 @@ To incorporate additional condition latents, we expand the input features of Cog
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
diff --git a/examples/cogview4-control/train_control_cogview4.py b/examples/cogview4-control/train_control_cogview4.py
index 7d2ce20949..93b33a189e 100644
--- a/examples/cogview4-control/train_control_cogview4.py
+++ b/examples/cogview4-control/train_control_cogview4.py
@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -705,7 +705,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_out_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/community/README.md b/examples/community/README.md
index 225a25fac7..e4fbd79366 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -87,6 +87,7 @@ PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixar
| CogVideoX DDIM Inversion Pipeline | Implementation of DDIM inversion and guided attention-based editing denoising process on CogVideoX. | [CogVideoX DDIM Inversion Pipeline](#cogvideox-ddim-inversion-pipeline) | - | [LittleNyima](https://github.com/LittleNyima) |
| FaithDiff Stable Diffusion XL Pipeline | Implementation of [(CVPR 2025) FaithDiff: Unleashing Diffusion Priors for Faithful Image Super-resolutionUnleashing Diffusion Priors for Faithful Image Super-resolution](https://huggingface.co/papers/2411.18824) - FaithDiff is a faithful image super-resolution method that leverages latent diffusion models by actively adapting the diffusion prior and jointly fine-tuning its components (encoder and diffusion model) with an alignment module to ensure high fidelity and structural consistency. | [FaithDiff Stable Diffusion XL Pipeline](#faithdiff-stable-diffusion-xl-pipeline) | [](https://huggingface.co/jychen9811/FaithDiff) | [Junyang Chen, Jinshan Pan, Jiangxin Dong, IMAG Lab, (Adapted by Eliseu Silva)](https://github.com/JyChen9811/FaithDiff) |
| Stable Diffusion 3 InstructPix2Pix Pipeline | Implementation of Stable Diffusion 3 InstructPix2Pix Pipeline | [Stable Diffusion 3 InstructPix2Pix Pipeline](#stable-diffusion-3-instructpix2pix-pipeline) | [](https://huggingface.co/BleachNick/SD3_UltraEdit_freeform) [](https://huggingface.co/CaptainZZZ/sd3-instructpix2pix) | [Jiayu Zhang](https://github.com/xduzhangjiayu) and [Haozhe Zhao](https://github.com/HaozheZhao)|
+| Flux Kontext multiple images | A modified version of the `FluxKontextPipeline` that supports calling Flux Kontext with multiple reference images.| [Flux Kontext multiple input Pipeline](#flux-kontext-multiple-images) | - | [Net-Mist](https://github.com/Net-Mist) |
To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
```py
@@ -3128,7 +3129,7 @@ from io import BytesIO
from diffusers import DiffusionPipeline
# load the pipeline
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
model_id_or_path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# can also be used with dreamlike-art/dreamlike-photoreal-2.0
pipe = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16, custom_pipeline="pipeline_fabric").to("cuda")
@@ -5479,4 +5480,48 @@ edited_image.save("edited_image.png")
### Note
This model is trained on 512x512, so input size is better on 512x512.
For better editing performance, please refer to this powerful model https://huggingface.co/BleachNick/SD3_UltraEdit_freeform and Paper "UltraEdit: Instruction-based Fine-Grained Image
-Editing at Scale", many thanks to their contribution!
\ No newline at end of file
+Editing at Scale", many thanks to their contribution!
+
+# Flux Kontext multiple images
+
+This implementation of Flux Kontext allows users to pass multiple reference images. Each image is encoded separately, and the resulting latent vectors are concatenated.
+
+As explained in Section 3 of [the paper](https://arxiv.org/pdf/2506.15742), the model's sequence concatenation mechanism can extend its capabilities to handle multiple reference images. However, note that the current version of Flux Kontext was not trained for this use case. In practice, stacking along the first axis does not yield correct results, while stacking along the other two axes appears to work.
+
+## Example Usage
+
+This pipeline loads two reference images and generates a new image based on them.
+
+```python
+import torch
+
+from diffusers import FluxKontextPipeline
+from diffusers.utils import load_image
+
+
+pipe = FluxKontextPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-Kontext-dev",
+ torch_dtype=torch.bfloat16,
+ custom_pipeline="pipeline_flux_kontext_multiple_images",
+)
+pipe.to("cuda")
+
+pikachu_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+).convert("RGB")
+cat_image = load_image(
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
+).convert("RGB")
+
+
+prompts = [
+ "Pikachu and the cat are sitting together at a pizzeria table, enjoying a delicious pizza.",
+]
+images = pipe(
+ multiple_images=[(pikachu_image, cat_image)],
+ prompt=prompts,
+ guidance_scale=2.5,
+ generator=torch.Generator().manual_seed(42),
+).images
+images[0].save("pizzeria.png")
+```
diff --git a/examples/community/marigold_depth_estimation.py b/examples/community/marigold_depth_estimation.py
index 453735411d..8be773c138 100644
--- a/examples/community/marigold_depth_estimation.py
+++ b/examples/community/marigold_depth_estimation.py
@@ -43,7 +43,7 @@ from diffusers.utils import BaseOutput, check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
class MarigoldDepthOutput(BaseOutput):
diff --git a/examples/community/pipeline_flux_kontext_multiple_images.py b/examples/community/pipeline_flux_kontext_multiple_images.py
new file mode 100644
index 0000000000..ef0c643a40
--- /dev/null
+++ b/examples/community/pipeline_flux_kontext_multiple_images.py
@@ -0,0 +1,1211 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, FluxTransformer2DModel
+from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+PipelineSeveralImagesInput = Union[
+ Tuple[PIL.Image.Image, ...],
+ Tuple[np.ndarray, ...],
+ Tuple[torch.Tensor, ...],
+ List[Tuple[PIL.Image.Image, ...]],
+ List[Tuple[np.ndarray, ...]],
+ List[Tuple[torch.Tensor, ...]],
+]
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+ >>> image = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... guidance_scale=2.5,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux Kontext pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def preprocess_image(self, image: PipelineImageInput, _auto_resize: bool, multiple_of: int) -> torch.Tensor:
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+ image = self.image_processor.preprocess(image, image_height, image_width)
+ return image
+
+ def preprocess_images(
+ self,
+ images: PipelineSeveralImagesInput,
+ _auto_resize: bool,
+ multiple_of: int,
+ ) -> torch.Tensor:
+ # TODO for reviewer: I'm not sure what's the best way to implement this part given the philosophy of the repo.
+ # The solutions I thought about are:
+ # - Make the `resize` and `preprocess` methods of `VaeImageProcessor` more generic (using TypeVar for instance)
+ # - Start by converting the image to a List[Tuple[ {image_format} ]], to unify the processing logic
+ # - Or duplicate the code, as done here.
+ # What do you think ?
+
+ # convert multiple_images to a list of tuple, to simplify following logic
+ if not isinstance(images, list):
+ images = [images]
+ # now multiple_images is a list of tuples.
+
+ img = images[0][0]
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ n_image_per_batch = len(images[0])
+ output_images = []
+ for i in range(n_image_per_batch):
+ image = [batch_images[i] for batch_images in images]
+ image = self.image_processor.resize(image, image_height, image_width)
+ image = self.image_processor.preprocess(image, image_height, image_width)
+ output_images.append(image)
+ return output_images
+
+ def prepare_latents(
+ self,
+ images: Optional[list[torch.Tensor]],
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ all_image_latents = []
+ all_image_ids = []
+ image_latents = images_ids = None
+ if images is not None:
+ for i, image in enumerate(images):
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ # set the image ids to the correct position in the latent grid
+ image_ids[..., 2] += i * (image_latent_height // 2)
+
+ all_image_ids.append(image_ids)
+ all_image_latents.append(image_latents)
+
+ image_latents = torch.cat(all_image_latents, dim=1)
+ images_ids = torch.cat(all_image_ids, dim=0)
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents, latent_ids, images_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ multiple_images: Optional[PipelineSeveralImagesInput] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Guidance scale as defined in [Classifier-Free Diffusion
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
+ the text `prompt`, usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+ multiple_images (`PipelineSeveralImagesInput`, *optional*):
+ A list of images to be used as reference images for the generation. If provided, the pipeline will
+ merge the reference images in the latent space.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Preprocess image
+ if image is not None and multiple_images is not None:
+ raise ValueError("Cannot pass both `image` and `multiple_images`. Please use only one of them.")
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ image = [self.preprocess_image(image, _auto_resize=True, multiple_of=multiple_of)]
+ if multiple_images is not None:
+ image = self.preprocess_images(
+ multiple_images,
+ _auto_resize=_auto_resize,
+ multiple_of=multiple_of,
+ )
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+ image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ if image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
index b254799756..5822967d05 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sd_wds.py
@@ -73,7 +73,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -877,7 +877,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
index 554319aef4..e7f64ef14d 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -66,7 +66,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -709,7 +709,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
index 52d4806100..4b79a59134 100644
--- a/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_lora_sdxl_wds.py
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -872,7 +872,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 3be506352f..057b86eaaa 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -842,7 +842,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index 5a28201bf7..09982f0546 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -78,7 +78,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -882,7 +882,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md
index 3b223c8c46..9976761739 100644
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -359,7 +359,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
We encourage you to store or share your model with the community. To use huggingface hub, please login to your Hugging Face account, or ([create one](https://huggingface.co/docs/diffusers/main/en/training/hf.co/join) if you don’t have one already):
```sh
-huggingface-cli login
+hf auth login
```
Make sure you have the `MODEL_DIR`,`OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model to on the Hub:
diff --git a/examples/controlnet/README_flux.md b/examples/controlnet/README_flux.md
index fcac6df110..fefe0148a5 100644
--- a/examples/controlnet/README_flux.md
+++ b/examples/controlnet/README_flux.md
@@ -22,7 +22,7 @@ Here is a gpu memory consumption for reference, tested on a single A100 with 80G
> **Gated access**
>
-> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `huggingface-cli login`
+> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: `hf auth login`
## Running locally with PyTorch
@@ -88,7 +88,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
-Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
+Then run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
we can define the num_layers, num_single_layers, which determines the size of the control(default values are num_layers=4, num_single_layers=10)
diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md
index b62e33362d..9c2d6aaac3 100644
--- a/examples/controlnet/README_sd3.md
+++ b/examples/controlnet/README_sd3.md
@@ -56,7 +56,7 @@ First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stab
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) or [Stable Diffusion 3.5 Large Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
diff --git a/examples/controlnet/README_sdxl.md b/examples/controlnet/README_sdxl.md
index 75511385ff..442cfd386a 100644
--- a/examples/controlnet/README_sdxl.md
+++ b/examples/controlnet/README_sdxl.md
@@ -58,7 +58,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
-Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
+Then run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index 69bd39944a..c9be7a7f92 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -734,7 +734,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 5561710d6f..2c08ffc49a 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = logging.getLogger(__name__)
@@ -665,7 +665,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging.basicConfig(
diff --git a/examples/controlnet/train_controlnet_flux.py b/examples/controlnet/train_controlnet_flux.py
index 94f030fe01..d281668e11 100644
--- a/examples/controlnet/train_controlnet_flux.py
+++ b/examples/controlnet/train_controlnet_flux.py
@@ -65,7 +65,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -814,7 +814,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_out_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index ecd7572ca3..033c9d7f26 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -928,7 +928,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
@@ -1330,7 +1330,7 @@ def main(args):
# controlnet(s) inference
controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
controlnet_image = vae.encode(controlnet_image).latent_dist.sample()
- controlnet_image = controlnet_image * vae.config.scaling_factor
+ controlnet_image = (controlnet_image - vae.config.shift_factor) * vae.config.scaling_factor
control_block_res_samples = controlnet(
hidden_states=noisy_model_input,
diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py
index 76d232da1c..3d182f8f4c 100644
--- a/examples/controlnet/train_controlnet_sdxl.py
+++ b/examples/controlnet/train_controlnet_sdxl.py
@@ -61,7 +61,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -829,7 +829,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 81992c3dd1..ce4fec0a12 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -63,7 +63,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -663,7 +663,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index f0697609b3..c6c119ff97 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -330,7 +330,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/dreambooth/README_flux.md b/examples/dreambooth/README_flux.md
index a3704f2789..242f018b65 100644
--- a/examples/dreambooth/README_flux.md
+++ b/examples/dreambooth/README_flux.md
@@ -19,7 +19,7 @@ The `train_dreambooth_flux.py` script shows how to implement the training proced
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
@@ -260,5 +260,97 @@ to enable `latent_caching` simply pass `--cache_latents`.
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
+## Training Kontext
+
+[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
+provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
+
+**important**
+
+> [!NOTE]
+> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
+> To do this, execute the following steps in a new virtual environment:
+> ```
+> git clone https://github.com/huggingface/diffusers
+> cd diffusers
+> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
+> pip install -e .
+> ```
+
+Below is an example training command:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+ --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+ --instance_data_dir="dog" \
+ --output_dir="kontext-dog" \
+ --mixed_precision="bf16" \
+ --instance_prompt="a photo of sks dog" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --optimizer="adamw" \
+ --use_8bit_adam \
+ --cache_latents \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=0 \
+ --max_train_steps=500 \
+ --seed="0"
+```
+
+Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
+perform as expected.
+
+Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
+
+* Condition image
+* Target image
+* Instruction
+
+[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
+
+```bash
+accelerate launch train_dreambooth_lora_flux_kontext.py \
+ --pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
+ --output_dir="kontext-i2i" \
+ --dataset_name="kontext-community/relighting" \
+ --image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
+ --mixed_precision="bf16" \
+ --resolution=1024 \
+ --train_batch_size=1 \
+ --guidance_scale=1 \
+ --gradient_accumulation_steps=4 \
+ --gradient_checkpointing \
+ --optimizer="adamw" \
+ --use_8bit_adam \
+ --cache_latents \
+ --learning_rate=1e-4 \
+ --lr_scheduler="constant" \
+ --lr_warmup_steps=200 \
+ --max_train_steps=1000 \
+ --rank=16\
+ --seed="0"
+```
+
+More generally, when performing I2I fine-tuning, we expect you to:
+
+* Have a dataset `kontext-community/relighting`
+* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
+
+### Misc notes
+
+* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
+### Aspect Ratio Bucketing
+we've added aspect ratio bucketing support which allows training on images with different aspect ratios without cropping them to a single square resolution. This technique helps preserve the original composition of training images and can improve training efficiency.
+
+To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
+
+`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
+`
+Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
+
## Other notes
-Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
\ No newline at end of file
+Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
diff --git a/examples/dreambooth/README_hidream.md b/examples/dreambooth/README_hidream.md
index 2c6b68f3f6..58df99d9f6 100644
--- a/examples/dreambooth/README_hidream.md
+++ b/examples/dreambooth/README_hidream.md
@@ -95,7 +95,7 @@ accelerate launch train_dreambooth_lora_hidream.py \
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
-huggingface-cli login
+hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
diff --git a/examples/dreambooth/README_lumina2.md b/examples/dreambooth/README_lumina2.md
index f691acd266..d8998ccbed 100644
--- a/examples/dreambooth/README_lumina2.md
+++ b/examples/dreambooth/README_lumina2.md
@@ -101,7 +101,7 @@ accelerate launch train_dreambooth_lora_lumina2.py \
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
-huggingface-cli login
+hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
diff --git a/examples/dreambooth/README_sana.md b/examples/dreambooth/README_sana.md
index 1cc189149b..7972434b5e 100644
--- a/examples/dreambooth/README_sana.md
+++ b/examples/dreambooth/README_sana.md
@@ -101,7 +101,7 @@ accelerate launch train_dreambooth_lora_sana.py \
For using `push_to_hub`, make you're logged into your Hugging Face account:
```bash
-huggingface-cli login
+hf auth login
```
To better track our training experiments, we're using the following flags in the command above:
diff --git a/examples/dreambooth/README_sd3.md b/examples/dreambooth/README_sd3.md
index 5b706930e9..91d540a446 100644
--- a/examples/dreambooth/README_sd3.md
+++ b/examples/dreambooth/README_sd3.md
@@ -8,7 +8,7 @@ The `train_dreambooth_sd3.py` script shows how to implement the training procedu
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
diff --git a/examples/dreambooth/test_dreambooth_lora_flux_kontext.py b/examples/dreambooth/test_dreambooth_lora_flux_kontext.py
new file mode 100644
index 0000000000..c12fdd79ee
--- /dev/null
+++ b/examples/dreambooth/test_dreambooth_lora_flux_kontext.py
@@ -0,0 +1,281 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import os
+import sys
+import tempfile
+
+import safetensors
+
+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+
+sys.path.append("..")
+from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+logger = logging.getLogger()
+stream_handler = logging.StreamHandler(sys.stdout)
+logger.addHandler(stream_handler)
+
+
+class DreamBoothLoRAFluxKontext(ExamplesTestsAccelerate):
+ instance_data_dir = "docs/source/en/imgs"
+ instance_prompt = "photo"
+ pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-kontext-pipe"
+ script_path = "examples/dreambooth/train_dreambooth_lora_flux_kontext.py"
+ transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
+
+ def test_dreambooth_lora_flux_kontext(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_text_encoder_flux_kontext(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --train_text_encoder
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ starts_with_expected_prefix = all(
+ (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_expected_prefix)
+
+ def test_dreambooth_lora_latent_caching(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names.
+ starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_layers(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --cache_latents
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lora_layers {self.transformer_layer_type}
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+ # make sure the state_dict has the correct naming in the parameters.
+ lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ is_lora = all("lora" in k for k in lora_state_dict.keys())
+ self.assertTrue(is_lora)
+
+ # when not training the text encoder, all the parameters in the state dict should start
+ # with `"transformer"` in their names. In this test, we only params of
+ # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
+ starts_with_transformer = all(
+ key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
+ )
+ self.assertTrue(starts_with_transformer)
+
+ def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=6
+ --checkpoints_total_limit=2
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual(
+ {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+ {"checkpoint-4", "checkpoint-6"},
+ )
+
+ def test_dreambooth_lora_flux_kontext_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=4
+ --checkpointing_steps=2
+ """.split()
+
+ run_command(self._launch_args + test_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
+
+ resume_run_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
+ --instance_data_dir={self.instance_data_dir}
+ --output_dir={tmpdir}
+ --instance_prompt={self.instance_prompt}
+ --resolution=64
+ --train_batch_size=1
+ --gradient_accumulation_steps=1
+ --max_train_steps=8
+ --checkpointing_steps=2
+ --resume_from_checkpoint=checkpoint-4
+ --checkpoints_total_limit=2
+ """.split()
+
+ run_command(self._launch_args + resume_run_args)
+
+ self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+ def test_dreambooth_lora_with_metadata(self):
+ # Use a `lora_alpha` that is different from `rank`.
+ lora_alpha = 8
+ rank = 4
+ with tempfile.TemporaryDirectory() as tmpdir:
+ test_args = f"""
+ {self.script_path}
+ --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+ --instance_data_dir {self.instance_data_dir}
+ --instance_prompt {self.instance_prompt}
+ --resolution 64
+ --train_batch_size 1
+ --gradient_accumulation_steps 1
+ --max_train_steps 2
+ --lora_alpha={lora_alpha}
+ --rank={rank}
+ --learning_rate 5.0e-04
+ --scale_lr
+ --lr_scheduler constant
+ --lr_warmup_steps 0
+ --output_dir {tmpdir}
+ """.split()
+
+ run_command(self._launch_args + test_args)
+ # save_pretrained smoke test
+ state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+ self.assertTrue(os.path.isfile(state_dict_file))
+
+ # Check if the metadata was properly serialized.
+ with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+ metadata = f.metadata() or {}
+
+ metadata.pop("format", None)
+ raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+ if raw:
+ raw = json.loads(raw)
+
+ loaded_lora_alpha = raw["transformer.lora_alpha"]
+ self.assertTrue(loaded_lora_alpha == lora_alpha)
+ loaded_lora_rank = raw["transformer.r"]
+ self.assertTrue(loaded_lora_rank == rank)
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index ec0cc686b0..1807e9bd80 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -807,7 +807,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py
index 4e61a04f24..ccf4626cf8 100644
--- a/examples/dreambooth/train_dreambooth_flax.py
+++ b/examples/dreambooth/train_dreambooth_flax.py
@@ -35,7 +35,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))
diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py
index 02b83bb6b1..b3e7560251 100644
--- a/examples/dreambooth/train_dreambooth_flux.py
+++ b/examples/dreambooth/train_dreambooth_flux.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import copy
import gc
@@ -65,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -1013,7 +1027,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 7c008970bd..aaf61f9813 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -74,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -756,7 +756,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py
index 9c529cbb92..6ec532e630 100644
--- a/examples/dreambooth/train_dreambooth_lora_flux.py
+++ b/examples/dreambooth/train_dreambooth_lora_flux.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=0.31.0",
+# "transformers>=4.41.2",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.11.1",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import copy
import itertools
@@ -72,7 +86,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -1051,7 +1065,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
new file mode 100644
index 0000000000..38896728fa
--- /dev/null
+++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py
@@ -0,0 +1,2208 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
+import argparse
+import copy
+import itertools
+import logging
+import math
+import os
+import random
+import shutil
+import warnings
+from contextlib import nullcontext
+from pathlib import Path
+
+import numpy as np
+import torch
+import transformers
+from accelerate import Accelerator
+from accelerate.logging import get_logger
+from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
+from huggingface_hub import create_repo, upload_folder
+from huggingface_hub.utils import insecure_hashlib
+from peft import LoraConfig, set_peft_model_state_dict
+from peft.utils import get_peft_model_state_dict
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from torch.utils.data import Dataset
+from torch.utils.data.sampler import BatchSampler
+from torchvision import transforms
+from torchvision.transforms import functional as TF
+from tqdm.auto import tqdm
+from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+
+import diffusers
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.optimization import get_scheduler
+from diffusers.training_utils import (
+ _collate_lora_metadata,
+ _set_state_dict_into_text_encoder,
+ cast_training_params,
+ compute_density_for_timestep_sampling,
+ compute_loss_weighting_for_sd3,
+ find_nearest_bucket,
+ free_memory,
+ parse_buckets_string,
+)
+from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
+from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
+from diffusers.utils.torch_utils import is_compiled_module
+
+
+if is_wandb_available():
+ import wandb
+
+# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+check_min_version("0.34.0.dev0")
+
+logger = get_logger(__name__)
+
+if is_torch_npu_available():
+ torch.npu.config.allow_internal_format = False
+
+
+def save_model_card(
+ repo_id: str,
+ images=None,
+ base_model: str = None,
+ train_text_encoder=False,
+ instance_prompt=None,
+ validation_prompt=None,
+ repo_folder=None,
+):
+ widget_dict = []
+ if images is not None:
+ for i, image in enumerate(images):
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
+ widget_dict.append(
+ {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
+ )
+
+ model_description = f"""
+# Flux Kontext DreamBooth LoRA - {repo_id}
+
+
+
+## Model description
+
+These are {repo_id} DreamBooth LoRA weights for {base_model}.
+
+The weights were trained using [DreamBooth](https://dreambooth.github.io/) with the [Flux diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_flux.md).
+
+Was LoRA for the text encoder enabled? {train_text_encoder}.
+
+## Trigger words
+
+You should use `{instance_prompt}` to trigger the image generation.
+
+## Download model
+
+[Download the *.safetensors LoRA]({repo_id}/tree/main) in the Files & versions tab.
+
+## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers)
+
+```py
+from diffusers import FluxKontextPipeline
+import torch
+pipeline = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to('cuda')
+pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
+image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
+```
+
+For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters)
+
+## License
+
+Please adhere to the licensing terms as described [here](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
+"""
+ model_card = load_or_create_model_card(
+ repo_id_or_path=repo_id,
+ from_training=True,
+ license="other",
+ base_model=base_model,
+ prompt=instance_prompt,
+ model_description=model_description,
+ widget=widget_dict,
+ )
+ tags = [
+ "text-to-image",
+ "diffusers-training",
+ "diffusers",
+ "lora",
+ "flux",
+ "flux-kontextflux-diffusers",
+ "template:sd-lora",
+ ]
+
+ model_card = populate_model_card(model_card, tags=tags)
+ model_card.save(os.path.join(repo_folder, "README.md"))
+
+
+def load_text_encoders(class_one, class_two):
+ text_encoder_one = class_one.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
+ )
+ text_encoder_two = class_two.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
+ )
+ return text_encoder_one, text_encoder_two
+
+
+def log_validation(
+ pipeline,
+ args,
+ accelerator,
+ pipeline_args,
+ epoch,
+ torch_dtype,
+ is_final_validation=False,
+):
+ logger.info(
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+ f" {args.validation_prompt}."
+ )
+ pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+ pipeline.set_progress_bar_config(disable=True)
+ pipeline_args_cp = pipeline_args.copy()
+
+ # run inference
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
+ autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
+
+ # pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
+ with torch.no_grad():
+ prompt = pipeline_args_cp.pop("prompt")
+ prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
+ images = []
+ for _ in range(args.num_validation_images):
+ with autocast_ctx:
+ image = pipeline(
+ **pipeline_args_cp,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ generator=generator,
+ ).images[0]
+ images.append(image)
+
+ for tracker in accelerator.trackers:
+ phase_name = "test" if is_final_validation else "validation"
+ if tracker.name == "tensorboard":
+ np_images = np.stack([np.asarray(img) for img in images])
+ tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+ if tracker.name == "wandb":
+ tracker.log(
+ {
+ phase_name: [
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+ ]
+ }
+ )
+
+ del pipeline
+ free_memory()
+
+ return images
+
+
+def import_model_class_from_model_name_or_path(
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
+):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
+ )
+ model_class = text_encoder_config.architectures[0]
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "T5EncoderModel":
+ from transformers import T5EncoderModel
+
+ return T5EncoderModel
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
+
+def parse_args(input_args=None):
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
+ parser.add_argument(
+ "--pretrained_model_name_or_path",
+ type=str,
+ default=None,
+ required=True,
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--revision",
+ type=str,
+ default=None,
+ required=False,
+ help="Revision of pretrained model identifier from huggingface.co/models.",
+ )
+ parser.add_argument(
+ "--vae_encode_mode",
+ type=str,
+ default="mode",
+ choices=["sample", "mode"],
+ help="VAE encoding mode.",
+ )
+ parser.add_argument(
+ "--variant",
+ type=str,
+ default=None,
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
+ )
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default=None,
+ help=(
+ "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+ " or to a folder containing files that 🤗 Datasets can understand."
+ ),
+ )
+ parser.add_argument(
+ "--dataset_config_name",
+ type=str,
+ default=None,
+ help="The config of the Dataset, leave as None if there's only one config.",
+ )
+ parser.add_argument(
+ "--instance_data_dir",
+ type=str,
+ default=None,
+ help=("A folder containing the training data. "),
+ )
+
+ parser.add_argument(
+ "--cache_dir",
+ type=str,
+ default=None,
+ help="The directory where the downloaded models and datasets will be stored.",
+ )
+
+ parser.add_argument(
+ "--image_column",
+ type=str,
+ default="image",
+ help="The column of the dataset containing the target image. By "
+ "default, the standard Image Dataset maps out 'file_name' "
+ "to 'image'.",
+ )
+ parser.add_argument(
+ "--cond_image_column",
+ type=str,
+ default=None,
+ help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
+ )
+ parser.add_argument(
+ "--caption_column",
+ type=str,
+ default=None,
+ help="The column of the dataset containing the instance prompt for each image",
+ )
+
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
+ parser.add_argument(
+ "--class_data_dir",
+ type=str,
+ default=None,
+ required=False,
+ help="A folder containing the training data of class images.",
+ )
+ parser.add_argument(
+ "--instance_prompt",
+ type=str,
+ default=None,
+ help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
+ )
+ parser.add_argument(
+ "--class_prompt",
+ type=str,
+ default=None,
+ help="The prompt to specify images in the same class as provided instance images.",
+ )
+ parser.add_argument(
+ "--max_sequence_length",
+ type=int,
+ default=512,
+ help="Maximum sequence length to use with with the T5 text encoder",
+ )
+ parser.add_argument(
+ "--validation_prompt",
+ type=str,
+ default=None,
+ help="A prompt that is used during validation to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--validation_image",
+ type=str,
+ default=None,
+ help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
+ )
+ parser.add_argument(
+ "--num_validation_images",
+ type=int,
+ default=4,
+ help="Number of images that should be generated during validation with `validation_prompt`.",
+ )
+ parser.add_argument(
+ "--validation_epochs",
+ type=int,
+ default=50,
+ help=(
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
+ ),
+ )
+ parser.add_argument(
+ "--rank",
+ type=int,
+ default=4,
+ help=("The dimension of the LoRA update matrices."),
+ )
+ parser.add_argument(
+ "--lora_alpha",
+ type=int,
+ default=4,
+ help="LoRA alpha to be used for additional scaling.",
+ )
+ parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")
+
+ parser.add_argument(
+ "--with_prior_preservation",
+ default=False,
+ action="store_true",
+ help="Flag to add prior preservation loss.",
+ )
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
+ parser.add_argument(
+ "--num_class_images",
+ type=int,
+ default=100,
+ help=(
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
+ " class_data_dir, additional images will be sampled with class_prompt."
+ ),
+ )
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="flux-kontext-lora",
+ help="The output directory where the model predictions and checkpoints will be written.",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+ parser.add_argument(
+ "--resolution",
+ type=int,
+ default=512,
+ help=(
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+ " resolution"
+ ),
+ )
+ parser.add_argument(
+ "--aspect_ratio_buckets",
+ type=str,
+ default=None,
+ help=(
+ "Aspect ratio buckets to use for training. Define as a string of 'h1,w1;h2,w2;...'. "
+ "e.g. '1024,1024;768,1360;1360,768;880,1168;1168,880;1248,832;832,1248'"
+ "Images will be resized and cropped to fit the nearest bucket. If provided, --resolution is ignored."
+ ),
+ )
+ parser.add_argument(
+ "--center_crop",
+ default=False,
+ action="store_true",
+ help=(
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+ " cropped. The images will be resized to the resolution first before cropping."
+ ),
+ )
+ parser.add_argument(
+ "--random_flip",
+ action="store_true",
+ help="whether to randomly flip images horizontally",
+ )
+ parser.add_argument(
+ "--train_text_encoder",
+ action="store_true",
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+ )
+ parser.add_argument(
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
+ )
+ parser.add_argument(
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
+ )
+ parser.add_argument("--num_train_epochs", type=int, default=1)
+ parser.add_argument(
+ "--max_train_steps",
+ type=int,
+ default=None,
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+ )
+ parser.add_argument(
+ "--checkpointing_steps",
+ type=int,
+ default=500,
+ help=(
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
+ " training using `--resume_from_checkpoint`."
+ ),
+ )
+ parser.add_argument(
+ "--checkpoints_total_limit",
+ type=int,
+ default=None,
+ help=("Max number of checkpoints to store."),
+ )
+ parser.add_argument(
+ "--resume_from_checkpoint",
+ type=str,
+ default=None,
+ help=(
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+ ),
+ )
+ parser.add_argument(
+ "--gradient_accumulation_steps",
+ type=int,
+ default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
+ )
+ parser.add_argument(
+ "--gradient_checkpointing",
+ action="store_true",
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+ )
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=1e-4,
+ help="Initial learning rate (after the potential warmup period) to use.",
+ )
+
+ parser.add_argument(
+ "--guidance_scale",
+ type=float,
+ default=3.5,
+ help="the FLUX.1 dev variant is a guidance distilled model",
+ )
+
+ parser.add_argument(
+ "--text_encoder_lr",
+ type=float,
+ default=5e-6,
+ help="Text encoder learning rate to use.",
+ )
+ parser.add_argument(
+ "--scale_lr",
+ action="store_true",
+ default=False,
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+ )
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="constant",
+ help=(
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+ ' "constant", "constant_with_warmup"]'
+ ),
+ )
+ parser.add_argument(
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+ )
+ parser.add_argument(
+ "--lr_num_cycles",
+ type=int,
+ default=1,
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+ )
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+ parser.add_argument(
+ "--dataloader_num_workers",
+ type=int,
+ default=0,
+ help=(
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
+ ),
+ )
+ parser.add_argument(
+ "--weighting_scheme",
+ type=str,
+ default="none",
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
+ )
+ parser.add_argument(
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
+ )
+ parser.add_argument(
+ "--mode_scale",
+ type=float,
+ default=1.29,
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
+ )
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="AdamW",
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
+ )
+
+ parser.add_argument(
+ "--use_8bit_adam",
+ action="store_true",
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+ )
+
+ parser.add_argument(
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+ )
+ parser.add_argument(
+ "--prodigy_beta3",
+ type=float,
+ default=None,
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+ parser.add_argument(
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+ )
+
+ parser.add_argument(
+ "--lora_layers",
+ type=str,
+ default=None,
+ help=(
+ 'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v,to_out.0" will result in lora training of attention layers only'
+ ),
+ )
+
+ parser.add_argument(
+ "--adam_epsilon",
+ type=float,
+ default=1e-08,
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+ )
+
+ parser.add_argument(
+ "--prodigy_use_bias_correction",
+ type=bool,
+ default=True,
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+ )
+ parser.add_argument(
+ "--prodigy_safeguard_warmup",
+ type=bool,
+ default=True,
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+ "Ignored if optimizer is adamW",
+ )
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
+ parser.add_argument(
+ "--hub_model_id",
+ type=str,
+ default=None,
+ help="The name of the repository to keep in sync with the local `output_dir`.",
+ )
+ parser.add_argument(
+ "--logging_dir",
+ type=str,
+ default="logs",
+ help=(
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+ ),
+ )
+ parser.add_argument(
+ "--allow_tf32",
+ action="store_true",
+ help=(
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+ ),
+ )
+ parser.add_argument(
+ "--cache_latents",
+ action="store_true",
+ default=False,
+ help="Cache the VAE latents",
+ )
+ parser.add_argument(
+ "--report_to",
+ type=str,
+ default="tensorboard",
+ help=(
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+ ),
+ )
+ parser.add_argument(
+ "--mixed_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp16", "bf16"],
+ help=(
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+ ),
+ )
+ parser.add_argument(
+ "--upcast_before_saving",
+ action="store_true",
+ default=False,
+ help=(
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
+ "Defaults to precision dtype used for training to save memory"
+ ),
+ )
+ parser.add_argument(
+ "--prior_generation_precision",
+ type=str,
+ default=None,
+ choices=["no", "fp32", "fp16", "bf16"],
+ help=(
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
+ ),
+ )
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+
+ if input_args is not None:
+ args = parser.parse_args(input_args)
+ else:
+ args = parser.parse_args()
+
+ if args.dataset_name is None and args.instance_data_dir is None:
+ raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+ if args.dataset_name is not None and args.instance_data_dir is not None:
+ raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
+ args.local_rank = env_local_rank
+
+ if args.with_prior_preservation:
+ if args.class_data_dir is None:
+ raise ValueError("You must specify a data directory for class images.")
+ if args.class_prompt is None:
+ raise ValueError("You must specify prompt for class images.")
+ if args.cond_image_column is not None:
+ raise ValueError("Prior preservation isn't supported with I2I training.")
+ else:
+ # logger is not available yet
+ if args.class_data_dir is not None:
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
+ if args.class_prompt is not None:
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
+
+ if args.cond_image_column is not None:
+ assert args.image_column is not None
+ assert args.caption_column is not None
+ assert args.dataset_name is not None
+ assert not args.train_text_encoder
+ if args.validation_prompt is not None:
+ assert args.validation_image is None and os.path.exists(args.validation_image)
+
+ return args
+
+
+class DreamBoothDataset(Dataset):
+ """
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
+ It pre-processes the images.
+ """
+
+ def __init__(
+ self,
+ instance_data_root,
+ instance_prompt,
+ class_prompt,
+ class_data_root=None,
+ class_num=None,
+ repeats=1,
+ center_crop=False,
+ buckets=None,
+ args=None,
+ ):
+ self.center_crop = center_crop
+
+ self.instance_prompt = instance_prompt
+ self.custom_instance_prompts = None
+ self.class_prompt = class_prompt
+
+ self.buckets = buckets
+
+ # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+ # we load the training data using load_dataset
+ if args.dataset_name is not None:
+ try:
+ from datasets import load_dataset
+ except ImportError:
+ raise ImportError(
+ "You are trying to load your data using the datasets library. If you wish to train using custom "
+ "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+ "local folder containing images only, specify --instance_data_dir instead."
+ )
+ # Downloading and loading a dataset from the hub.
+ # See more about loading custom images at
+ # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+ dataset = load_dataset(
+ args.dataset_name,
+ args.dataset_config_name,
+ cache_dir=args.cache_dir,
+ )
+ # Preprocessing the datasets.
+ column_names = dataset["train"].column_names
+
+ # 6. Get the column names for input/target.
+ if args.cond_image_column is not None and args.cond_image_column not in column_names:
+ raise ValueError(
+ f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ if args.image_column is None:
+ image_column = column_names[0]
+ logger.info(f"image column defaulting to {image_column}")
+ else:
+ image_column = args.image_column
+ if image_column not in column_names:
+ raise ValueError(
+ f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
+ cond_images = None
+ cond_image_column = args.cond_image_column
+ if cond_image_column is not None:
+ cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
+ assert len(instance_images) == len(cond_images)
+
+ if args.caption_column is None:
+ logger.info(
+ "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+ "contains captions/prompts for the images, make sure to specify the "
+ "column as --caption_column"
+ )
+ self.custom_instance_prompts = None
+ else:
+ if args.caption_column not in column_names:
+ raise ValueError(
+ f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+ )
+ custom_instance_prompts = dataset["train"][args.caption_column]
+ # create final list of captions according to --repeats
+ self.custom_instance_prompts = []
+ for caption in custom_instance_prompts:
+ self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+ else:
+ self.instance_data_root = Path(instance_data_root)
+ if not self.instance_data_root.exists():
+ raise ValueError("Instance images root doesn't exists.")
+
+ instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+ self.custom_instance_prompts = None
+
+ self.instance_images = []
+ self.cond_images = []
+ for i, img in enumerate(instance_images):
+ self.instance_images.extend(itertools.repeat(img, repeats))
+ if args.dataset_name is not None and cond_images is not None:
+ self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
+
+ self.pixel_values = []
+ self.cond_pixel_values = []
+ for i, image in enumerate(self.instance_images):
+ image = exif_transpose(image)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ dest_image = None
+ if self.cond_images:
+ dest_image = exif_transpose(self.cond_images[i])
+ if not dest_image.mode == "RGB":
+ dest_image = dest_image.convert("RGB")
+
+ width, height = image.size
+
+ # Find the closest bucket
+ bucket_idx = find_nearest_bucket(height, width, self.buckets)
+ target_height, target_width = self.buckets[bucket_idx]
+ self.size = (target_height, target_width)
+
+ # based on the bucket assignment, define the transformations
+ image, dest_image = self.paired_transform(
+ image,
+ dest_image=dest_image,
+ size=self.size,
+ center_crop=args.center_crop,
+ random_flip=args.random_flip,
+ )
+ self.pixel_values.append((image, bucket_idx))
+ if dest_image is not None:
+ self.cond_pixel_values.append((dest_image, bucket_idx))
+
+ self.num_instance_images = len(self.instance_images)
+ self._length = self.num_instance_images
+
+ if class_data_root is not None:
+ self.class_data_root = Path(class_data_root)
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
+ self.class_images_path = list(self.class_data_root.iterdir())
+ if class_num is not None:
+ self.num_class_images = min(len(self.class_images_path), class_num)
+ else:
+ self.num_class_images = len(self.class_images_path)
+ self._length = max(self.num_class_images, self.num_instance_images)
+ else:
+ self.class_data_root = None
+
+ self.image_transforms = transforms.Compose(
+ [
+ transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
+ transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+ )
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ example = {}
+ instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+ example["instance_images"] = instance_image
+ example["bucket_idx"] = bucket_idx
+ if self.cond_pixel_values:
+ dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
+ example["cond_images"] = dest_image
+
+ if self.custom_instance_prompts:
+ caption = self.custom_instance_prompts[index % self.num_instance_images]
+ if caption:
+ example["instance_prompt"] = caption
+ else:
+ example["instance_prompt"] = self.instance_prompt
+
+ else: # custom prompts were provided, but length does not match size of image dataset
+ example["instance_prompt"] = self.instance_prompt
+
+ if self.class_data_root:
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
+ class_image = exif_transpose(class_image)
+
+ if not class_image.mode == "RGB":
+ class_image = class_image.convert("RGB")
+ example["class_images"] = self.image_transforms(class_image)
+ example["class_prompt"] = self.class_prompt
+
+ return example
+
+ def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
+ # 1. Resize (deterministic)
+ resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+ image = resize(image)
+ if dest_image is not None:
+ dest_image = resize(dest_image)
+
+ # 2. Crop: either center or SAME random crop
+ if center_crop:
+ crop = transforms.CenterCrop(size)
+ image = crop(image)
+ if dest_image is not None:
+ dest_image = crop(dest_image)
+ else:
+ # get_params returns (i, j, h, w)
+ i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
+ image = TF.crop(image, i, j, h, w)
+ if dest_image is not None:
+ dest_image = TF.crop(dest_image, i, j, h, w)
+
+ # 3. Random horizontal flip with the SAME coin flip
+ if random_flip:
+ do_flip = random.random() < 0.5
+ if do_flip:
+ image = TF.hflip(image)
+ if dest_image is not None:
+ dest_image = TF.hflip(dest_image)
+
+ # 4. ToTensor + Normalize (deterministic)
+ to_tensor = transforms.ToTensor()
+ normalize = transforms.Normalize([0.5], [0.5])
+ image = normalize(to_tensor(image))
+ if dest_image is not None:
+ dest_image = normalize(to_tensor(dest_image))
+
+ return (image, dest_image) if dest_image is not None else (image, None)
+
+
+def collate_fn(examples, with_prior_preservation=False):
+ pixel_values = [example["instance_images"] for example in examples]
+ prompts = [example["instance_prompt"] for example in examples]
+
+ # Concat class and instance examples for prior preservation.
+ # We do this to avoid doing two forward passes.
+ if with_prior_preservation:
+ pixel_values += [example["class_images"] for example in examples]
+ prompts += [example["class_prompt"] for example in examples]
+
+ pixel_values = torch.stack(pixel_values)
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+
+ batch = {"pixel_values": pixel_values, "prompts": prompts}
+ if any("cond_images" in example for example in examples):
+ cond_pixel_values = [example["cond_images"] for example in examples]
+ cond_pixel_values = torch.stack(cond_pixel_values)
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
+ batch.update({"cond_pixel_values": cond_pixel_values})
+ return batch
+
+
+class BucketBatchSampler(BatchSampler):
+ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
+ if not isinstance(batch_size, int) or batch_size <= 0:
+ raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
+ if not isinstance(drop_last, bool):
+ raise ValueError("drop_last should be a boolean value, but got drop_last={}".format(drop_last))
+
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self.drop_last = drop_last
+
+ # Group indices by bucket
+ self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
+ for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
+ self.bucket_indices[bucket_idx].append(idx)
+
+ self.sampler_len = 0
+ self.batches = []
+
+ # Pre-generate batches for each bucket
+ for indices_in_bucket in self.bucket_indices:
+ # Shuffle indices within the bucket
+ random.shuffle(indices_in_bucket)
+ # Create batches
+ for i in range(0, len(indices_in_bucket), self.batch_size):
+ batch = indices_in_bucket[i : i + self.batch_size]
+ if len(batch) < self.batch_size and self.drop_last:
+ continue # Skip partial batch if drop_last is True
+ self.batches.append(batch)
+ self.sampler_len += 1 # Count the number of batches
+
+ def __iter__(self):
+ # Shuffle the order of the batches each epoch
+ random.shuffle(self.batches)
+ for batch in self.batches:
+ yield batch
+
+ def __len__(self):
+ return self.sampler_len
+
+
+class PromptDataset(Dataset):
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
+
+ def __init__(self, prompt, num_samples):
+ self.prompt = prompt
+ self.num_samples = num_samples
+
+ def __len__(self):
+ return self.num_samples
+
+ def __getitem__(self, index):
+ example = {}
+ example["prompt"] = self.prompt
+ example["index"] = index
+ return example
+
+
+def tokenize_prompt(tokenizer, prompt, max_sequence_length):
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ return text_input_ids
+
+
+def _encode_prompt_with_t5(
+ text_encoder,
+ tokenizer,
+ max_sequence_length=512,
+ prompt=None,
+ num_images_per_prompt=1,
+ device=None,
+ text_input_ids=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
+
+ if hasattr(text_encoder, "module"):
+ dtype = text_encoder.module.dtype
+ else:
+ dtype = text_encoder.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+
+def _encode_prompt_with_clip(
+ text_encoder,
+ tokenizer,
+ prompt: str,
+ device=None,
+ text_input_ids=None,
+ num_images_per_prompt: int = 1,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if tokenizer is not None:
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=77,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ else:
+ if text_input_ids is None:
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ if hasattr(text_encoder, "module"):
+ dtype = text_encoder.module.dtype
+ else:
+ dtype = text_encoder.dtype
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+
+def encode_prompt(
+ text_encoders,
+ tokenizers,
+ prompt: str,
+ max_sequence_length,
+ device=None,
+ num_images_per_prompt: int = 1,
+ text_input_ids_list=None,
+):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if hasattr(text_encoders[0], "module"):
+ dtype = text_encoders[0].module.dtype
+ else:
+ dtype = text_encoders[0].dtype
+
+ pooled_prompt_embeds = _encode_prompt_with_clip(
+ text_encoder=text_encoders[0],
+ tokenizer=tokenizers[0],
+ prompt=prompt,
+ device=device if device is not None else text_encoders[0].device,
+ num_images_per_prompt=num_images_per_prompt,
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
+ )
+
+ prompt_embeds = _encode_prompt_with_t5(
+ text_encoder=text_encoders[1],
+ tokenizer=tokenizers[1],
+ max_sequence_length=max_sequence_length,
+ prompt=prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device if device is not None else text_encoders[1].device,
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
+ )
+
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+
+def main(args):
+ if args.report_to == "wandb" and args.hub_token is not None:
+ raise ValueError(
+ "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
+ " Please use `hf auth login` to authenticate with the Hub."
+ )
+
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ logging_dir = Path(args.output_dir, args.logging_dir)
+
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision,
+ log_with=args.report_to,
+ project_config=accelerator_project_config,
+ kwargs_handlers=[kwargs],
+ )
+
+ # Disable AMP for MPS.
+ if torch.backends.mps.is_available():
+ accelerator.native_amp = False
+
+ if args.report_to == "wandb":
+ if not is_wandb_available():
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+
+ # Make one log on every process with the configuration for debugging.
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ level=logging.INFO,
+ )
+ logger.info(accelerator.state, main_process_only=False)
+ if accelerator.is_local_main_process:
+ transformers.utils.logging.set_verbosity_warning()
+ diffusers.utils.logging.set_verbosity_info()
+ else:
+ transformers.utils.logging.set_verbosity_error()
+ diffusers.utils.logging.set_verbosity_error()
+
+ # If passed along, set the training seed now.
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # Generate class images if prior preservation is enabled.
+ if args.with_prior_preservation:
+ class_images_dir = Path(args.class_data_dir)
+ if not class_images_dir.exists():
+ class_images_dir.mkdir(parents=True)
+ cur_class_images = len(list(class_images_dir.iterdir()))
+
+ if cur_class_images < args.num_class_images:
+ has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+ torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
+ if args.prior_generation_precision == "fp32":
+ torch_dtype = torch.float32
+ elif args.prior_generation_precision == "fp16":
+ torch_dtype = torch.float16
+ elif args.prior_generation_precision == "bf16":
+ torch_dtype = torch.bfloat16
+
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="transformer",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline = FluxKontextPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ torch_dtype=torch_dtype,
+ revision=args.revision,
+ variant=args.variant,
+ )
+ pipeline.set_progress_bar_config(disable=True)
+
+ num_new_images = args.num_class_images - cur_class_images
+ logger.info(f"Number of class images to sample: {num_new_images}.")
+
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+ sample_dataloader = accelerator.prepare(sample_dataloader)
+ pipeline.to(accelerator.device)
+
+ for example in tqdm(
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
+ ):
+ images = pipeline(example["prompt"]).images
+
+ for i, image in enumerate(images):
+ hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
+ image.save(image_filename)
+
+ del pipeline
+ free_memory()
+
+ # Handle the repository creation
+ if accelerator.is_main_process:
+ if args.output_dir is not None:
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ if args.push_to_hub:
+ repo_id = create_repo(
+ repo_id=args.hub_model_id or Path(args.output_dir).name,
+ exist_ok=True,
+ ).repo_id
+
+ # Load the tokenizers
+ tokenizer_one = CLIPTokenizer.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer",
+ revision=args.revision,
+ )
+ tokenizer_two = T5TokenizerFast.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="tokenizer_2",
+ revision=args.revision,
+ )
+
+ # import correct text encoder classes
+ text_encoder_cls_one = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision
+ )
+ text_encoder_cls_two = import_model_class_from_model_name_or_path(
+ args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
+ )
+
+ # Load scheduler and models
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="scheduler"
+ )
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ vae = AutoencoderKL.from_pretrained(
+ args.pretrained_model_name_or_path,
+ subfolder="vae",
+ revision=args.revision,
+ variant=args.variant,
+ )
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+
+ # We only train the additional adapter LoRA layers
+ transformer.requires_grad_(False)
+ vae.requires_grad_(False)
+ text_encoder_one.requires_grad_(False)
+ text_encoder_two.requires_grad_(False)
+
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
+ # as these weights are only used for inference, keeping weights in full precision is not required.
+ weight_dtype = torch.float32
+ if accelerator.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif accelerator.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
+ # due to pytorch#99272, MPS does not yet support bfloat16.
+ raise ValueError(
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
+ )
+
+ vae.to(accelerator.device, dtype=weight_dtype)
+ transformer.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
+
+ if args.gradient_checkpointing:
+ transformer.enable_gradient_checkpointing()
+ if args.train_text_encoder:
+ text_encoder_one.gradient_checkpointing_enable()
+
+ if args.lora_layers is not None:
+ target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+ else:
+ target_modules = [
+ "attn.to_k",
+ "attn.to_q",
+ "attn.to_v",
+ "attn.to_out.0",
+ "attn.add_k_proj",
+ "attn.add_q_proj",
+ "attn.add_v_proj",
+ "attn.to_add_out",
+ "ff.net.0.proj",
+ "ff.net.2",
+ "ff_context.net.0.proj",
+ "ff_context.net.2",
+ "proj_mlp",
+ ]
+
+ # now we will add new LoRA weights the transformer layers
+ transformer_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=target_modules,
+ )
+ transformer.add_adapter(transformer_lora_config)
+ if args.train_text_encoder:
+ text_lora_config = LoraConfig(
+ r=args.rank,
+ lora_alpha=args.lora_alpha,
+ lora_dropout=args.lora_dropout,
+ init_lora_weights="gaussian",
+ target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+ )
+ text_encoder_one.add_adapter(text_lora_config)
+
+ def unwrap_model(model):
+ model = accelerator.unwrap_model(model)
+ model = model._orig_mod if is_compiled_module(model) else model
+ return model
+
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
+ def save_model_hook(models, weights, output_dir):
+ if accelerator.is_main_process:
+ transformer_lora_layers_to_save = None
+ text_encoder_one_lora_layers_to_save = None
+ modules_to_save = {}
+ for model in models:
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+ modules_to_save["transformer"] = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+ modules_to_save["text_encoder"] = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ # make sure to pop weight so that corresponding model is not saved again
+ weights.pop()
+
+ FluxKontextPipeline.save_lora_weights(
+ output_dir,
+ transformer_lora_layers=transformer_lora_layers_to_save,
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ def load_model_hook(models, input_dir):
+ transformer_ = None
+ text_encoder_one_ = None
+
+ while len(models) > 0:
+ model = models.pop()
+
+ if isinstance(model, type(unwrap_model(transformer))):
+ transformer_ = model
+ elif isinstance(model, type(unwrap_model(text_encoder_one))):
+ text_encoder_one_ = model
+ else:
+ raise ValueError(f"unexpected save model: {model.__class__}")
+
+ lora_state_dict = FluxKontextPipeline.lora_state_dict(input_dir)
+
+ transformer_state_dict = {
+ f"{k.replace('transformer.', '')}": v for k, v in lora_state_dict.items() if k.startswith("transformer.")
+ }
+ transformer_state_dict = convert_unet_state_dict_to_peft(transformer_state_dict)
+ incompatible_keys = set_peft_model_state_dict(transformer_, transformer_state_dict, adapter_name="default")
+ if incompatible_keys is not None:
+ # check only for unexpected keys
+ unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+ if unexpected_keys:
+ logger.warning(
+ f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+ f" {unexpected_keys}. "
+ )
+ if args.train_text_encoder:
+ # Do we need to call `scale_lora_layers()` here?
+ _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_)
+
+ # Make sure the trainable params are in float32. This is again needed since the base models
+ # are in `weight_dtype`. More details:
+ # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
+ if args.mixed_precision == "fp16":
+ models = [transformer_]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one_])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models)
+
+ accelerator.register_save_state_pre_hook(save_model_hook)
+ accelerator.register_load_state_pre_hook(load_model_hook)
+
+ # Enable TF32 for faster training on Ampere GPUs,
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+ if args.allow_tf32 and torch.cuda.is_available():
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+ if args.scale_lr:
+ args.learning_rate = (
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+ )
+
+ # Make sure the trainable params are in float32.
+ if args.mixed_precision == "fp16":
+ models = [transformer]
+ if args.train_text_encoder:
+ models.extend([text_encoder_one])
+ # only upcast trainable parameters (LoRA) into fp32
+ cast_training_params(models, dtype=torch.float32)
+
+ transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
+ if args.train_text_encoder:
+ text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+
+ # Optimization parameters
+ transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
+ if args.train_text_encoder:
+ # different learning rate for text encoder and unet
+ text_parameters_one_with_lr = {
+ "params": text_lora_parameters_one,
+ "weight_decay": args.adam_weight_decay_text_encoder,
+ "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+ }
+ params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
+ else:
+ params_to_optimize = [transformer_parameters_with_lr]
+
+ # Optimizer creation
+ if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+ logger.warning(
+ f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
+ "Defaulting to adamW"
+ )
+ args.optimizer = "adamw"
+
+ if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+ logger.warning(
+ f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+ f"set to {args.optimizer.lower()}"
+ )
+
+ if args.optimizer.lower() == "adamw":
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError(
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+ )
+
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ )
+
+ if args.optimizer.lower() == "prodigy":
+ try:
+ import prodigyopt
+ except ImportError:
+ raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+ optimizer_class = prodigyopt.Prodigy
+
+ if args.learning_rate <= 0.1:
+ logger.warning(
+ "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+ )
+ if args.train_text_encoder and args.text_encoder_lr:
+ logger.warning(
+ f"Learning rates were provided both for the transformer and the text encoder- e.g. text_encoder_lr:"
+ f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+ f"When using prodigy only learning_rate is used as the initial learning rate."
+ )
+ # changes the learning rate of text_encoder_parameters_one to be
+ # --learning_rate
+ params_to_optimize[1]["lr"] = args.learning_rate
+
+ optimizer = optimizer_class(
+ params_to_optimize,
+ betas=(args.adam_beta1, args.adam_beta2),
+ beta3=args.prodigy_beta3,
+ weight_decay=args.adam_weight_decay,
+ eps=args.adam_epsilon,
+ decouple=args.prodigy_decouple,
+ use_bias_correction=args.prodigy_use_bias_correction,
+ safeguard_warmup=args.prodigy_safeguard_warmup,
+ )
+
+ if args.aspect_ratio_buckets is not None:
+ buckets = parse_buckets_string(args.aspect_ratio_buckets)
+ else:
+ buckets = [(args.resolution, args.resolution)]
+ logger.info(f"Using parsed aspect ratio buckets: {buckets}")
+
+ # Dataset and DataLoaders creation:
+ train_dataset = DreamBoothDataset(
+ instance_data_root=args.instance_data_dir,
+ instance_prompt=args.instance_prompt,
+ class_prompt=args.class_prompt,
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+ class_num=args.num_class_images,
+ buckets=buckets,
+ repeats=args.repeats,
+ center_crop=args.center_crop,
+ args=args,
+ )
+ if args.cond_image_column is not None:
+ logger.info("I2I fine-tuning enabled.")
+ batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset,
+ batch_sampler=batch_sampler,
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+ num_workers=args.dataloader_num_workers,
+ )
+
+ if not args.train_text_encoder:
+ tokenizers = [tokenizer_one, tokenizer_two]
+ text_encoders = [text_encoder_one, text_encoder_two]
+
+ def compute_text_embeddings(prompt, text_encoders, tokenizers):
+ with torch.no_grad():
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders, tokenizers, prompt, args.max_sequence_length
+ )
+ prompt_embeds = prompt_embeds.to(accelerator.device)
+ pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
+ text_ids = text_ids.to(accelerator.device)
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+ # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+ # the redundant encoding.
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ instance_prompt_hidden_states, instance_pooled_prompt_embeds, instance_text_ids = compute_text_embeddings(
+ args.instance_prompt, text_encoders, tokenizers
+ )
+
+ # Handle class prompt for prior-preservation.
+ if args.with_prior_preservation:
+ if not args.train_text_encoder:
+ class_prompt_hidden_states, class_pooled_prompt_embeds, class_text_ids = compute_text_embeddings(
+ args.class_prompt, text_encoders, tokenizers
+ )
+
+ # Clear the memory here
+ if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
+ del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ free_memory()
+
+ # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+ # pack the statically computed variables appropriately here. This is so that we don't
+ # have to pass them to the dataloader.
+
+ if not train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds = instance_prompt_hidden_states
+ pooled_prompt_embeds = instance_pooled_prompt_embeds
+ text_ids = instance_text_ids
+ if args.with_prior_preservation:
+ prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+ pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, class_pooled_prompt_embeds], dim=0)
+ text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+ # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts)
+ # we need to tokenize and encode the batch prompts on all training steps
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt, max_sequence_length=77)
+ tokens_two = tokenize_prompt(
+ tokenizer_two, args.instance_prompt, max_sequence_length=args.max_sequence_length
+ )
+ if args.with_prior_preservation:
+ class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt, max_sequence_length=77)
+ class_tokens_two = tokenize_prompt(
+ tokenizer_two, args.class_prompt, max_sequence_length=args.max_sequence_length
+ )
+ tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+ tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
+
+ elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
+ cached_text_embeddings = []
+ for batch in tqdm(train_dataloader, desc="Embedding prompts"):
+ batch_prompts = batch["prompts"]
+ prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+ batch_prompts, text_encoders, tokenizers
+ )
+ cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
+
+ if args.validation_prompt is None:
+ text_encoder_one.cpu(), text_encoder_two.cpu()
+ del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+ free_memory()
+
+ vae_config_shift_factor = vae.config.shift_factor
+ vae_config_scaling_factor = vae.config.scaling_factor
+ vae_config_block_out_channels = vae.config.block_out_channels
+ has_image_input = args.cond_image_column is not None
+ if args.cache_latents:
+ latents_cache = []
+ cond_latents_cache = []
+ for batch in tqdm(train_dataloader, desc="Caching latents"):
+ with torch.no_grad():
+ batch["pixel_values"] = batch["pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+ if has_image_input:
+ batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
+ accelerator.device, non_blocking=True, dtype=weight_dtype
+ )
+ cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
+
+ if args.validation_prompt is None:
+ vae.cpu()
+ del vae
+ free_memory()
+
+ # Scheduler and math around the number of training steps.
+ # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+ num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
+ if args.max_train_steps is None:
+ len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+ num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+ num_training_steps_for_scheduler = (
+ args.num_train_epochs * accelerator.num_processes * num_update_steps_per_epoch
+ )
+ else:
+ num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
+
+ lr_scheduler = get_scheduler(
+ args.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=num_warmup_steps_for_scheduler,
+ num_training_steps=num_training_steps_for_scheduler,
+ num_cycles=args.lr_num_cycles,
+ power=args.lr_power,
+ )
+
+ # Prepare everything with our `accelerator`.
+ if args.train_text_encoder:
+ (
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ ) = accelerator.prepare(
+ transformer,
+ text_encoder_one,
+ optimizer,
+ train_dataloader,
+ lr_scheduler,
+ )
+ else:
+ transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ transformer, optimizer, train_dataloader, lr_scheduler
+ )
+
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ if args.max_train_steps is None:
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+ if num_training_steps_for_scheduler != args.max_train_steps:
+ logger.warning(
+ f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+ f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+ f"This inconsistency may result in the learning rate scheduler not functioning properly."
+ )
+ # Afterwards we recalculate our number of training epochs
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # We need to initialize the trackers we use, and also store our configuration.
+ # The trackers initializes automatically on the main process.
+ if accelerator.is_main_process:
+ tracker_name = "dreambooth-flux-kontext-lora"
+ accelerator.init_trackers(tracker_name, config=vars(args))
+
+ # Train!
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {len(train_dataset)}")
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
+ global_step = 0
+ first_epoch = 0
+
+ # Potentially load in the weights and states from a previous save
+ if args.resume_from_checkpoint:
+ if args.resume_from_checkpoint != "latest":
+ path = os.path.basename(args.resume_from_checkpoint)
+ else:
+ # Get the mos recent checkpoint
+ dirs = os.listdir(args.output_dir)
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
+ path = dirs[-1] if len(dirs) > 0 else None
+
+ if path is None:
+ accelerator.print(
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
+ )
+ args.resume_from_checkpoint = None
+ initial_global_step = 0
+ else:
+ accelerator.print(f"Resuming from checkpoint {path}")
+ accelerator.load_state(os.path.join(args.output_dir, path))
+ global_step = int(path.split("-")[1])
+
+ initial_global_step = global_step
+ first_epoch = global_step // num_update_steps_per_epoch
+
+ else:
+ initial_global_step = 0
+
+ progress_bar = tqdm(
+ range(0, args.max_train_steps),
+ initial=initial_global_step,
+ desc="Steps",
+ # Only show the progress bar once on each machine.
+ disable=not accelerator.is_local_main_process,
+ )
+
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
+ timesteps = timesteps.to(accelerator.device)
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+ sigma = sigmas[step_indices].flatten()
+ while len(sigma.shape) < n_dim:
+ sigma = sigma.unsqueeze(-1)
+ return sigma
+
+ has_guidance = unwrap_model(transformer).config.guidance_embeds
+ for epoch in range(first_epoch, args.num_train_epochs):
+ transformer.train()
+ if args.train_text_encoder:
+ text_encoder_one.train()
+ # set top parameter requires_grad = True for gradient checkpointing works
+ unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True)
+
+ for step, batch in enumerate(train_dataloader):
+ models_to_accumulate = [transformer]
+ if args.train_text_encoder:
+ models_to_accumulate.extend([text_encoder_one])
+ with accelerator.accumulate(models_to_accumulate):
+ prompts = batch["prompts"]
+
+ # encode batch prompts when custom prompts are provided for each image -
+ if train_dataset.custom_instance_prompts:
+ if not args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
+ else:
+ tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
+ tokens_two = tokenize_prompt(
+ tokenizer_two, prompts, max_sequence_length=args.max_sequence_length
+ )
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[tokens_one, tokens_two],
+ max_sequence_length=args.max_sequence_length,
+ device=accelerator.device,
+ prompt=prompts,
+ )
+ else:
+ elems_to_repeat = len(prompts)
+ if args.train_text_encoder:
+ prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
+ text_encoders=[text_encoder_one, text_encoder_two],
+ tokenizers=[None, None],
+ text_input_ids_list=[
+ tokens_one.repeat(elems_to_repeat, 1),
+ tokens_two.repeat(elems_to_repeat, 1),
+ ],
+ max_sequence_length=args.max_sequence_length,
+ device=accelerator.device,
+ prompt=args.instance_prompt,
+ )
+
+ # Convert images to latent space
+ if args.cache_latents:
+ if args.vae_encode_mode == "sample":
+ model_input = latents_cache[step].sample()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].sample()
+ else:
+ model_input = latents_cache[step].mode()
+ if has_image_input:
+ cond_model_input = cond_latents_cache[step].mode()
+ else:
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ if has_image_input:
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
+ if args.vae_encode_mode == "sample":
+ model_input = vae.encode(pixel_values).latent_dist.sample()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
+ else:
+ model_input = vae.encode(pixel_values).latent_dist.mode()
+ if has_image_input:
+ cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ model_input = model_input.to(dtype=weight_dtype)
+ if has_image_input:
+ cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
+ cond_model_input = cond_model_input.to(dtype=weight_dtype)
+
+ vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
+
+ latent_image_ids = FluxKontextPipeline._prepare_latent_image_ids(
+ model_input.shape[0],
+ model_input.shape[2] // 2,
+ model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+ if has_image_input:
+ cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
+ cond_model_input.shape[0],
+ cond_model_input.shape[2] // 2,
+ cond_model_input.shape[3] // 2,
+ accelerator.device,
+ weight_dtype,
+ )
+ cond_latents_ids[..., 0] = 1
+ latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(model_input)
+ bsz = model_input.shape[0]
+
+ # Sample a random timestep for each image
+ # for weighting schemes where we sample timesteps non-uniformly
+ u = compute_density_for_timestep_sampling(
+ weighting_scheme=args.weighting_scheme,
+ batch_size=bsz,
+ logit_mean=args.logit_mean,
+ logit_std=args.logit_std,
+ mode_scale=args.mode_scale,
+ )
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
+
+ # Add noise according to flow matching.
+ # zt = (1 - texp) * x + texp * z1
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
+ packed_noisy_model_input = FluxKontextPipeline._pack_latents(
+ noisy_model_input,
+ batch_size=model_input.shape[0],
+ num_channels_latents=model_input.shape[1],
+ height=model_input.shape[2],
+ width=model_input.shape[3],
+ )
+ orig_inp_shape = packed_noisy_model_input.shape
+ if has_image_input:
+ packed_cond_input = FluxKontextPipeline._pack_latents(
+ cond_model_input,
+ batch_size=cond_model_input.shape[0],
+ num_channels_latents=cond_model_input.shape[1],
+ height=cond_model_input.shape[2],
+ width=cond_model_input.shape[3],
+ )
+ packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
+
+ # Kontext always has guidance
+ guidance = None
+ if has_guidance:
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
+ guidance = guidance.expand(model_input.shape[0])
+
+ # Predict the noise residual
+ model_pred = transformer(
+ hidden_states=packed_noisy_model_input,
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
+ timestep=timesteps / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_image_ids,
+ return_dict=False,
+ )[0]
+ if has_image_input:
+ model_pred = model_pred[:, : orig_inp_shape[1]]
+ model_pred = FluxKontextPipeline._unpack_latents(
+ model_pred,
+ height=model_input.shape[2] * vae_scale_factor,
+ width=model_input.shape[3] * vae_scale_factor,
+ vae_scale_factor=vae_scale_factor,
+ )
+
+ # these weighting schemes use a uniform timestep sampling
+ # and instead post-weight the loss
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
+
+ # flow matching loss
+ target = noise - model_input
+
+ if args.with_prior_preservation:
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+ target, target_prior = torch.chunk(target, 2, dim=0)
+
+ # Compute prior loss
+ prior_loss = torch.mean(
+ (weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
+ target_prior.shape[0], -1
+ ),
+ 1,
+ )
+ prior_loss = prior_loss.mean()
+
+ # Compute regular loss.
+ loss = torch.mean(
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
+ 1,
+ )
+ loss = loss.mean()
+
+ if args.with_prior_preservation:
+ # Add the prior loss to the instance loss.
+ loss = loss + args.prior_loss_weight * prior_loss
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = (
+ itertools.chain(transformer.parameters(), text_encoder_one.parameters())
+ if args.train_text_encoder
+ else transformer.parameters()
+ )
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad()
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ if accelerator.is_main_process:
+ if global_step % args.checkpointing_steps == 0:
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
+ if args.checkpoints_total_limit is not None:
+ checkpoints = os.listdir(args.output_dir)
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
+
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
+ if len(checkpoints) >= args.checkpoints_total_limit:
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
+ removing_checkpoints = checkpoints[0:num_to_remove]
+
+ logger.info(
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
+ )
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
+
+ for removing_checkpoint in removing_checkpoints:
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
+ shutil.rmtree(removing_checkpoint)
+
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
+ accelerator.save_state(save_path)
+ logger.info(f"Saved state to {save_path}")
+
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+ accelerator.log(logs, step=global_step)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if accelerator.is_main_process:
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+ # create pipeline
+ if not args.train_text_encoder:
+ text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+ text_encoder_one.to(weight_dtype)
+ text_encoder_two.to(weight_dtype)
+ pipeline = FluxKontextPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ vae=vae,
+ text_encoder=unwrap_model(text_encoder_one),
+ text_encoder_2=unwrap_model(text_encoder_two),
+ transformer=unwrap_model(transformer),
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ torch_dtype=weight_dtype,
+ )
+ if not args.train_text_encoder:
+ del text_encoder_one, text_encoder_two
+ free_memory()
+
+ images = None
+ free_memory()
+
+ # Save the lora layers
+ accelerator.wait_for_everyone()
+ if accelerator.is_main_process:
+ modules_to_save = {}
+ transformer = unwrap_model(transformer)
+ if args.upcast_before_saving:
+ transformer.to(torch.float32)
+ else:
+ transformer = transformer.to(weight_dtype)
+ transformer_lora_layers = get_peft_model_state_dict(transformer)
+ modules_to_save["transformer"] = transformer
+
+ if args.train_text_encoder:
+ text_encoder_one = unwrap_model(text_encoder_one)
+ text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+ modules_to_save["text_encoder"] = text_encoder_one
+ else:
+ text_encoder_lora_layers = None
+
+ FluxKontextPipeline.save_lora_weights(
+ save_directory=args.output_dir,
+ transformer_lora_layers=transformer_lora_layers,
+ text_encoder_lora_layers=text_encoder_lora_layers,
+ **_collate_lora_metadata(modules_to_save),
+ )
+
+ # Final inference
+ # Load previous pipeline
+ transformer = FluxTransformer2DModel.from_pretrained(
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+ )
+ pipeline = FluxKontextPipeline.from_pretrained(
+ args.pretrained_model_name_or_path,
+ transformer=transformer,
+ revision=args.revision,
+ variant=args.variant,
+ torch_dtype=weight_dtype,
+ )
+ # load attention processors
+ pipeline.load_lora_weights(args.output_dir)
+
+ # run inference
+ images = []
+ if args.validation_prompt and args.num_validation_images > 0:
+ pipeline_args = {"prompt": args.validation_prompt}
+ if has_image_input and args.validation_image:
+ pipeline_args.update({"image": load_image(args.validation_image)})
+ images = log_validation(
+ pipeline=pipeline,
+ args=args,
+ accelerator=accelerator,
+ pipeline_args=pipeline_args,
+ epoch=epoch,
+ is_final_validation=True,
+ torch_dtype=weight_dtype,
+ )
+ del pipeline
+ free_memory()
+
+ if args.push_to_hub:
+ save_model_card(
+ repo_id,
+ images=images,
+ base_model=args.pretrained_model_name_or_path,
+ train_text_encoder=args.train_text_encoder,
+ instance_prompt=args.instance_prompt,
+ validation_prompt=args.validation_prompt,
+ repo_folder=args.output_dir,
+ )
+ upload_folder(
+ repo_id=repo_id,
+ folder_path=args.output_dir,
+ commit_message="End of training",
+ ignore_patterns=["step_*", "epoch_*"],
+ )
+
+ images = None
+
+ accelerator.end_training()
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)
diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py
index a1337e8dba..199a8a68ea 100644
--- a/examples/dreambooth/train_dreambooth_lora_hidream.py
+++ b/examples/dreambooth/train_dreambooth_lora_hidream.py
@@ -58,6 +58,7 @@ from diffusers.training_utils import (
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
free_memory,
+ offload_models,
)
from diffusers.utils import (
check_min_version,
@@ -73,7 +74,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.33.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -935,7 +936,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
@@ -1364,43 +1365,34 @@ def main(args):
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
# the redundant encoding.
if not train_dataset.custom_instance_prompts:
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
- (
- instance_prompt_hidden_states_t5,
- instance_prompt_hidden_states_llama3,
- instance_pooled_prompt_embeds,
- _,
- _,
- _,
- ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ (
+ instance_prompt_hidden_states_t5,
+ instance_prompt_hidden_states_llama3,
+ instance_pooled_prompt_embeds,
+ _,
+ _,
+ _,
+ ) = compute_text_embeddings(args.instance_prompt, text_encoding_pipeline)
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
- (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
- compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
- )
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ (class_prompt_hidden_states_t5, class_prompt_hidden_states_llama3, class_pooled_prompt_embeds, _, _, _) = (
+ compute_text_embeddings(args.class_prompt, text_encoding_pipeline)
+ )
validation_embeddings = {}
if args.validation_prompt is not None:
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
- (
- validation_embeddings["prompt_embeds_t5"],
- validation_embeddings["prompt_embeds_llama3"],
- validation_embeddings["pooled_prompt_embeds"],
- validation_embeddings["negative_prompt_embeds_t5"],
- validation_embeddings["negative_prompt_embeds_llama3"],
- validation_embeddings["negative_pooled_prompt_embeds"],
- ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
- if args.offload:
- text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+ with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
+ (
+ validation_embeddings["prompt_embeds_t5"],
+ validation_embeddings["prompt_embeds_llama3"],
+ validation_embeddings["pooled_prompt_embeds"],
+ validation_embeddings["negative_prompt_embeds_t5"],
+ validation_embeddings["negative_prompt_embeds_llama3"],
+ validation_embeddings["negative_pooled_prompt_embeds"],
+ ) = compute_text_embeddings(args.validation_prompt, text_encoding_pipeline)
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
@@ -1581,12 +1573,10 @@ def main(args):
if args.cache_latents:
model_input = latents_cache[step].sample()
else:
- if args.offload:
- vae = vae.to(accelerator.device)
- pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+ with offload_models(vae, device=accelerator.device, offload=args.offload):
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
model_input = vae.encode(pixel_values).latent_dist.sample()
- if args.offload:
- vae = vae.to("cpu")
+
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
model_input = model_input.to(dtype=weight_dtype)
diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py
index da499bce71..ee84de66d2 100644
--- a/examples/dreambooth/train_dreambooth_lora_lumina2.py
+++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -859,7 +859,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py
index c156523db3..2c4e63fd95 100644
--- a/examples/dreambooth/train_dreambooth_lora_sana.py
+++ b/examples/dreambooth/train_dreambooth_lora_sana.py
@@ -13,6 +13,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
+# /// script
+# dependencies = [
+# "diffusers @ git+https://github.com/huggingface/diffusers.git",
+# "torch>=2.0.0",
+# "accelerate>=1.0.0",
+# "transformers>=4.47.0",
+# "ftfy",
+# "tensorboard",
+# "Jinja2",
+# "peft>=0.14.0",
+# "sentencepiece",
+# ]
+# ///
+
import argparse
import copy
import itertools
@@ -72,7 +86,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -852,7 +866,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py
index 05dfe6301f..5ab21df518 100644
--- a/examples/dreambooth/train_dreambooth_lora_sd3.py
+++ b/examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -72,7 +72,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -1063,7 +1063,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index c3dfc923f0..5758db8508 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -79,7 +79,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -983,7 +983,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py
index 8d5dee0188..b130b9ff21 100644
--- a/examples/dreambooth/train_dreambooth_sd3.py
+++ b/examples/dreambooth/train_dreambooth_sd3.py
@@ -63,7 +63,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -988,7 +988,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/flux-control/README.md b/examples/flux-control/README.md
index 14afa499db..5463fc1552 100644
--- a/examples/flux-control/README.md
+++ b/examples/flux-control/README.md
@@ -13,7 +13,7 @@ To incorporate additional condition latents, we expand the input features of Flu
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 3be0182f6d..63cb770ccd 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -54,7 +54,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -697,7 +697,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_out_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 34755209ce..2990d5701a 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -57,7 +57,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -725,7 +725,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.use_lora_bias and args.gaussian_init_lora:
raise ValueError("`gaussian` LoRA init scheme isn't supported when `use_lora_bias` is True.")
@@ -837,11 +837,6 @@ def main(args):
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
- if args.train_norm_layers:
- for name, param in flux_transformer.named_parameters():
- if any(k in name for k in NORM_LAYER_PREFIXES):
- param.requires_grad = True
-
if args.lora_layers is not None:
if args.lora_layers != "all-linear":
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
)
flux_transformer.add_adapter(transformer_lora_config)
+ if args.train_norm_layers:
+ for name, param in flux_transformer.named_parameters():
+ if any(k in name for k in NORM_LAYER_PREFIXES):
+ param.requires_grad = True
+
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix.py b/examples/instruct_pix2pix/train_instruct_pix2pix.py
index 9f536139ab..b6b29fce27 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix.py
@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -430,7 +430,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 1b01f61738..ef55321f58 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -483,7 +483,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/kandinsky2_2/text_to_image/README.md b/examples/kandinsky2_2/text_to_image/README.md
index c14e02f6d0..c6afca8689 100644
--- a/examples/kandinsky2_2/text_to_image/README.md
+++ b/examples/kandinsky2_2/text_to_image/README.md
@@ -41,7 +41,7 @@ For all our examples, we will directly store the trained weights on the Hub, so
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
We also use [Weights and Biases](https://docs.wandb.ai/quickstart) logging by default, because it is really useful to monitor the training progress by regularly generating sample images during training. To install wandb, run
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index f2c5047d75..56a8136ab2 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -52,7 +52,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -444,7 +444,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 5b39c25901..7461f5b742 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -330,7 +330,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index 8c31f8f03b..64fd8ba3cb 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -46,7 +46,7 @@ from diffusers.utils import check_min_version, is_wandb_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -342,7 +342,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index 1f16c2d21a..fd4694d862 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -51,7 +51,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -445,7 +445,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/model_search/pipeline_easy.py b/examples/model_search/pipeline_easy.py
index b82e98fb71..fcce297c37 100644
--- a/examples/model_search/pipeline_easy.py
+++ b/examples/model_search/pipeline_easy.py
@@ -1249,7 +1249,7 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login`.
@@ -1358,7 +1358,7 @@ class EasyPipelineForText2Image(AutoPipelineForText2Image):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login`.
@@ -1507,7 +1507,7 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login`.
@@ -1617,7 +1617,7 @@ class EasyPipelineForImage2Image(AutoPipelineForImage2Image):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login`.
@@ -1766,7 +1766,7 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login
@@ -1875,7 +1875,7 @@ class EasyPipelineForInpainting(AutoPipelineForInpainting):
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ `hf auth login
diff --git a/examples/research_projects/autoencoderkl/train_autoencoderkl.py b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
index 31cf8414ac..dfb9e42ef1 100644
--- a/examples/research_projects/autoencoderkl/train_autoencoderkl.py
+++ b/examples/research_projects/autoencoderkl/train_autoencoderkl.py
@@ -568,7 +568,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
index c873356eb2..5cca8eea89 100644
--- a/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
+++ b/examples/research_projects/consistency_training/train_cm_ct_unconditional.py
@@ -789,7 +789,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
diff --git a/examples/research_projects/controlnet/train_controlnet_webdataset.py b/examples/research_projects/controlnet/train_controlnet_webdataset.py
index 9744bc7be2..f33a65c756 100644
--- a/examples/research_projects/controlnet/train_controlnet_webdataset.py
+++ b/examples/research_projects/controlnet/train_controlnet_webdataset.py
@@ -899,7 +899,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
index 8ea0768604..fda2a15809 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo.py
@@ -470,7 +470,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
index d11f961def..aa39b0b517 100644
--- a/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
+++ b/examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py
@@ -512,7 +512,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
index 12eb67d4a7..46045d330b 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora.py
@@ -502,7 +502,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
index 9f96ef944a..93418bf910 100644
--- a/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
+++ b/examples/research_projects/diffusion_orpo/train_diffusion_orpo_sdxl_lora_wds.py
@@ -609,7 +609,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/flux_lora_quantization/README.md b/examples/research_projects/flux_lora_quantization/README.md
index c0d76ac9bc..71ed28520a 100644
--- a/examples/research_projects/flux_lora_quantization/README.md
+++ b/examples/research_projects/flux_lora_quantization/README.md
@@ -39,7 +39,7 @@ python compute_embeddings.py
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
```bash
-huggingface-cli
+hf auth login
```
Then launch:
diff --git a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
index ca61664059..572c69fddf 100644
--- a/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
+++ b/examples/research_projects/flux_lora_quantization/train_dreambooth_lora_flux_miniature.py
@@ -587,7 +587,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/research_projects/gligen/README.md b/examples/research_projects/gligen/README.md
index fa922617d9..3da23306ce 100644
--- a/examples/research_projects/gligen/README.md
+++ b/examples/research_projects/gligen/README.md
@@ -47,11 +47,11 @@ pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
Download the pre-trained model:
```bash
-huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
-huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
-huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
-huggingface-cli download --resume-download clip-vit-large-patch14
-huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
+hf download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
+hf download --resume-download IDEA-Research/grounding-dino-base
+hf download --resume-download Salesforce/blip2-flan-t5-xxl
+hf download --resume-download clip-vit-large-patch14
+hf download --resume-download masterful/gligen-1-4-generation-text-box
```
Make the training data on 8 GPUs:
@@ -66,7 +66,7 @@ torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
You can download the COCO training data from
```bash
-huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
+hf download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```
It's in the format of
@@ -125,7 +125,7 @@ Note that although the pre-trained GLIGEN model has been loaded, the parameters
The trained model can be downloaded from
```bash
-huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
+hf download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```
You can run `demo.ipynb` to visualize the generated images.
diff --git a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
index ac754dc9c5..06079fe9ed 100644
--- a/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
+++ b/examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py
@@ -488,7 +488,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
index ea4a0d255b..740a759420 100644
--- a/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
+++ b/examples/research_projects/intel_opts/textual_inversion/textual_inversion_bf16.py
@@ -366,7 +366,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/lora/README.md b/examples/research_projects/lora/README.md
index 589d3e9c0f..55b870b0bc 100644
--- a/examples/research_projects/lora/README.md
+++ b/examples/research_projects/lora/README.md
@@ -34,7 +34,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py
index a734c50d8e..a9079c114f 100644
--- a/examples/research_projects/lora/train_text_to_image_lora.py
+++ b/examples/research_projects/lora/train_text_to_image_lora.py
@@ -396,7 +396,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
index 57c555e43f..6b0ae5ba97 100644
--- a/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
+++ b/examples/research_projects/multi_subject_dreambooth/train_multi_subject_dreambooth.py
@@ -684,7 +684,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/multi_token_textual_inversion/README.md b/examples/research_projects/multi_token_textual_inversion/README.md
index 16847c2cce..7d80c0beee 100644
--- a/examples/research_projects/multi_token_textual_inversion/README.md
+++ b/examples/research_projects/multi_token_textual_inversion/README.md
@@ -60,7 +60,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
index 75dcfccbd5..ffcc8a75c8 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion.py
@@ -551,7 +551,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py b/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
index ecc89f9829..a5973e1490 100644
--- a/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
+++ b/examples/research_projects/multi_token_textual_inversion/textual_inversion_flax.py
@@ -153,7 +153,7 @@ def parse_args():
"--use_auth_token",
action="store_true",
help=(
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ "Will use the token generated when running `hf auth login` (necessary to use this script with"
" private models)."
),
)
diff --git a/examples/research_projects/onnxruntime/text_to_image/README.md b/examples/research_projects/onnxruntime/text_to_image/README.md
index f1f134c576..f398f08166 100644
--- a/examples/research_projects/onnxruntime/text_to_image/README.md
+++ b/examples/research_projects/onnxruntime/text_to_image/README.md
@@ -41,7 +41,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
index ef910fab40..dd4c341ca8 100644
--- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py
@@ -415,7 +415,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/research_projects/onnxruntime/textual_inversion/README.md b/examples/research_projects/onnxruntime/textual_inversion/README.md
index a0ca4f954b..fa6d95af30 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/README.md
+++ b/examples/research_projects/onnxruntime/textual_inversion/README.md
@@ -46,7 +46,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
index a881b06a94..28bf029af4 100644
--- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
+++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py
@@ -566,7 +566,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 9a00f7cc4a..acbb77fe3a 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -280,7 +280,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/pixart/train_pixart_controlnet_hf.py b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
index ec954505c2..e2f1fa1bc5 100644
--- a/examples/research_projects/pixart/train_pixart_controlnet_hf.py
+++ b/examples/research_projects/pixart/train_pixart_controlnet_hf.py
@@ -562,7 +562,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/pytorch_xla/inference/flux/README.md b/examples/research_projects/pytorch_xla/inference/flux/README.md
index 9d482e6805..0bbd650bb6 100644
--- a/examples/research_projects/pytorch_xla/inference/flux/README.md
+++ b/examples/research_projects/pytorch_xla/inference/flux/README.md
@@ -40,7 +40,7 @@ cd examples/research_projects/pytorch_xla/inference/flux/
As the model is gated, before using it with diffusers you first need to go to the [FLUX.1 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.1-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
Then run:
diff --git a/examples/research_projects/pytorch_xla/training/text_to_image/README.md b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
index 06013b8a61..f99ab12486 100644
--- a/examples/research_projects/pytorch_xla/training/text_to_image/README.md
+++ b/examples/research_projects/pytorch_xla/training/text_to_image/README.md
@@ -80,7 +80,7 @@ pip3 install .'
Run the following command to authenticate your token.
```bash
-huggingface-cli login
+hf auth login
```
This script only trains the unet part of the network. The VAE and text encoder
diff --git a/examples/research_projects/realfill/train_realfill.py b/examples/research_projects/realfill/train_realfill.py
index 419636d131..fd63f71b5f 100644
--- a/examples/research_projects/realfill/train_realfill.py
+++ b/examples/research_projects/realfill/train_realfill.py
@@ -535,7 +535,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/sana/README.md b/examples/research_projects/sana/README.md
index ae80d11df4..933f32e3f9 100644
--- a/examples/research_projects/sana/README.md
+++ b/examples/research_projects/sana/README.md
@@ -19,7 +19,7 @@ mkdir -p $your_local_path # Create the directory if it doesn't exist
Download the SANA Sprint teacher model from Hugging Face Hub. The script uses the 1.6B parameter model.
```bash
-huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
+hf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
```
*(Optional: You can also download the 0.6B model by replacing the model name: `Efficient-Large-Model/Sana_Sprint_0.6B_1024px_teacher_diffusers`)*
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
index 335d9c377c..51db15f194 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.py
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -940,7 +940,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.sh b/examples/research_projects/sana/train_sana_sprint_diffusers.sh
index 301fe5e429..acd49ad67f 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.sh
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.sh
@@ -1,6 +1,6 @@
your_local_path='output'
-huggingface-cli download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
+hf download Efficient-Large-Model/SANA_Sprint_1.6B_1024px_teacher_diffusers --local-dir $your_local_path/SANA_Sprint_1.6B_1024px_teacher_diffusers
# or Sana_Sprint_0.6B_1024px_teacher_diffusers
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
index fd5b83a66e..50ab487bfe 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth.py
@@ -854,7 +854,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
index 393f991387..5ce510861a 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora.py
@@ -782,7 +782,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
index f011871c25..554aaedd7b 100644
--- a/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1054,7 +1054,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.do_edm_style_training and args.snr_gamma is not None:
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
index d867a5dd6a..c92b0ac053 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image.py
@@ -547,7 +547,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
index d01d5838f2..b7aa7b7bbb 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora.py
@@ -442,7 +442,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
index d9efca5ba5..715852cb72 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_lora_sdxl.py
@@ -537,7 +537,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
index 88880f5669..5a26fd3074 100644
--- a/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/research_projects/scheduled_huber_loss_training/text_to_image/train_text_to_image_sdxl.py
@@ -630,7 +630,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/sd3_lora_colab/README.md b/examples/research_projects/sd3_lora_colab/README.md
index 33fc7030de..be1bddf983 100644
--- a/examples/research_projects/sd3_lora_colab/README.md
+++ b/examples/research_projects/sd3_lora_colab/README.md
@@ -6,7 +6,7 @@ This is an **EDUCATIONAL** project that provides utilities for DreamBooth LoRA t
> SD3 is gated, so you need to make sure you agree to [share your contact info](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) to access the model before using it with Diffusers. Once you have access, you need to log in so your system knows you’re authorized. Use the command below to log in:
```bash
-huggingface-cli login
+hf auth login
```
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
diff --git a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
index 8e8190a593..79c3169b63 100644
--- a/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
+++ b/examples/research_projects/sd3_lora_colab/sd3_dreambooth_lora_16gb.ipynb
@@ -60,7 +60,7 @@
},
"outputs": [],
"source": [
- "!huggingface-cli login"
+ "!hf auth login"
]
},
{
@@ -2425,4 +2425,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
diff --git a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
index 21eb57ddc2..d73aab7363 100644
--- a/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
+++ b/examples/research_projects/sd3_lora_colab/train_dreambooth_lora_sd3_miniature.py
@@ -623,7 +623,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
diff --git a/examples/research_projects/wuerstchen/text_to_image/README.md b/examples/research_projects/wuerstchen/text_to_image/README.md
index 118c5e0cf9..8df068a873 100644
--- a/examples/research_projects/wuerstchen/text_to_image/README.md
+++ b/examples/research_projects/wuerstchen/text_to_image/README.md
@@ -26,7 +26,7 @@ accelerate config
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:
```bash
-huggingface-cli login
+hf auth login
```
## Prior training
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
index 9e2302f1b1..12586b5f57 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
@@ -446,7 +446,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
index 83647097d2..e72152b45c 100644
--- a/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
+++ b/examples/research_projects/wuerstchen/text_to_image/train_text_to_image_prior.py
@@ -444,7 +444,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/server/requirements.in b/examples/server/requirements.in
index b49b285a8f..a469569a10 100644
--- a/examples/server/requirements.in
+++ b/examples/server/requirements.in
@@ -1,4 +1,4 @@
-torch~=2.4.0
+torch~=2.7.0
transformers==4.46.1
sentencepiece
aiohttp
diff --git a/examples/server/requirements.txt b/examples/server/requirements.txt
index 5d840811f8..b91a8861a0 100644
--- a/examples/server/requirements.txt
+++ b/examples/server/requirements.txt
@@ -1,10 +1,10 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements.in -o requirements.txt
-aiohappyeyeballs==2.4.3
+aiohappyeyeballs==2.6.1
# via aiohttp
-aiohttp==3.10.10
+aiohttp==3.12.14
# via -r requirements.in
-aiosignal==1.3.1
+aiosignal==1.4.0
# via aiohttp
annotated-types==0.7.0
# via pydantic
@@ -29,7 +29,6 @@ filelock==3.16.1
# huggingface-hub
# torch
# transformers
- # triton
frozenlist==1.5.0
# via
# aiohttp
@@ -63,36 +62,42 @@ networkx==3.2.1
# via torch
numpy==2.0.2
# via transformers
-nvidia-cublas-cu12==12.1.3.1
+nvidia-cublas-cu12==12.6.4.1
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
-nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-cupti-cu12==12.6.80
# via torch
-nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.6.77
# via torch
-nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.6.77
# via torch
-nvidia-cudnn-cu12==9.1.0.70
+nvidia-cudnn-cu12==9.5.1.17
# via torch
-nvidia-cufft-cu12==11.0.2.54
+nvidia-cufft-cu12==11.3.0.4
# via torch
-nvidia-curand-cu12==10.3.2.106
+nvidia-cufile-cu12==1.11.1.6
# via torch
-nvidia-cusolver-cu12==11.4.5.107
+nvidia-curand-cu12==10.3.7.77
# via torch
-nvidia-cusparse-cu12==12.1.0.106
+nvidia-cusolver-cu12==11.7.1.2
+ # via torch
+nvidia-cusparse-cu12==12.5.4.2
# via
# nvidia-cusolver-cu12
# torch
-nvidia-nccl-cu12==2.20.5
+nvidia-cusparselt-cu12==0.6.3
# via torch
-nvidia-nvjitlink-cu12==12.9.86
+nvidia-nccl-cu12==2.26.2
+ # via torch
+nvidia-nvjitlink-cu12==12.6.85
# via
+ # nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
-nvidia-nvtx-cu12==12.1.105
+ # torch
+nvidia-nvtx-cu12==12.6.77
# via torch
packaging==24.1
# via
@@ -105,7 +110,9 @@ prometheus-client==0.21.0
prometheus-fastapi-instrumentator==7.0.0
# via -r requirements.in
propcache==0.2.0
- # via yarl
+ # via
+ # aiohttp
+ # yarl
py-consul==1.5.3
# via -r requirements.in
pydantic==2.9.2
@@ -137,7 +144,7 @@ sympy==1.13.3
# via torch
tokenizers==0.20.1
# via transformers
-torch==2.4.1
+torch==2.7.0
# via -r requirements.in
tqdm==4.66.5
# via
@@ -145,10 +152,11 @@ tqdm==4.66.5
# transformers
transformers==4.46.1
# via -r requirements.in
-triton==3.0.0
+triton==3.3.0
# via torch
typing-extensions==4.12.2
# via
+ # aiosignal
# anyio
# exceptiongroup
# fastapi
@@ -163,5 +171,5 @@ urllib3==2.5.0
# via requests
uvicorn==0.32.0
# via -r requirements.in
-yarl==1.16.0
+yarl==1.18.3
# via aiohttp
diff --git a/examples/t2i_adapter/README_sdxl.md b/examples/t2i_adapter/README_sdxl.md
index 1e5a19feda..0a3b5e33d4 100644
--- a/examples/t2i_adapter/README_sdxl.md
+++ b/examples/t2i_adapter/README_sdxl.md
@@ -58,7 +58,7 @@ wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/ma
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```
-Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
+Then run `hf auth login` to log into your Hugging Face account. This is needed to be able to push the trained T2IAdapter parameters to Hugging Face Hub.
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
diff --git a/examples/t2i_adapter/train_t2i_adapter_sdxl.py b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
index cb8fade444..acbee19fa5 100644
--- a/examples/t2i_adapter/train_t2i_adapter_sdxl.py
+++ b/examples/t2i_adapter/train_t2i_adapter_sdxl.py
@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -783,7 +783,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index 940d40c7b2..ebbf0a96be 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -43,7 +43,7 @@ You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need
Run the following command to authenticate your token
```bash
-huggingface-cli login
+hf auth login
```
If you have already cloned the repo, then you won't need to go through these steps.
@@ -215,7 +215,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/text_to_image/README_sdxl.md b/examples/text_to_image/README_sdxl.md
index c0b7840f10..6fb10ec9e1 100644
--- a/examples/text_to_image/README_sdxl.md
+++ b/examples/text_to_image/README_sdxl.md
@@ -156,7 +156,7 @@ For this example we want to directly store the trained LoRA embeddings on the Hu
we need to be logged in and add the `--push_to_hub` flag.
```bash
-huggingface-cli login
+hf auth login
```
Now we can start training!
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 17f5dc852b..bbd8fc062e 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -57,7 +57,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -531,7 +531,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
if args.non_ema_revision is not None:
diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py
index d9c1aafe80..74423dcf27 100644
--- a/examples/text_to_image/train_text_to_image_flax.py
+++ b/examples/text_to_image/train_text_to_image_flax.py
@@ -49,7 +49,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = logging.getLogger(__name__)
@@ -264,7 +264,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging.basicConfig(
diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 89f867b5ba..19968c2547 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -450,7 +450,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 12afb72b9a..88be919727 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -68,7 +68,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -555,7 +555,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 65a6131e66..dec202fbbf 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -55,7 +55,7 @@ from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
if is_torch_npu_available():
@@ -601,7 +601,7 @@ def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md
index 2f79107edb..06e22dbcd8 100644
--- a/examples/textual_inversion/README.md
+++ b/examples/textual_inversion/README.md
@@ -41,7 +41,7 @@ accelerate config
First, let's login so that we can upload the checkpoint to the Hub during training:
```bash
-huggingface-cli login
+hf auth login
```
Now let's get our dataset. For this example we will use some cat images: https://huggingface.co/datasets/diffusers/cat_toy_example .
diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 6dcc2ff7dc..e31ba9bd0c 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -81,7 +81,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -594,7 +594,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py
index 44c46995a1..f5863d94b0 100644
--- a/examples/textual_inversion/textual_inversion_flax.py
+++ b/examples/textual_inversion/textual_inversion_flax.py
@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ def parse_args():
"--use_auth_token",
action="store_true",
help=(
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script with"
+ "Will use the token generated when running `hf auth login` (necessary to use this script with"
" private models)."
),
)
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index add15a8583..1752bfd3b1 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -76,7 +76,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__)
@@ -593,7 +593,7 @@ def main():
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
- " Please use `huggingface-cli login` to authenticate with the Hub."
+ " Please use `hf auth login` to authenticate with the Hub."
)
logging_dir = os.path.join(args.output_dir, args.logging_dir)
diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index 2990b3abf3..22f982509b 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -151,7 +151,7 @@ dataset = load_dataset("imagefolder", data_files={"train": ["path/to/file1", "pa
Next, push it to the hub!
```python
-# assuming you have ran the huggingface-cli login command in a terminal
+# assuming you have ran the hf auth login command in a terminal
dataset.push_to_hub("name_of_your_dataset")
# if you want to push to a private repo, simply pass private=True:
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index baf2a9d899..892c674575 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -29,7 +29,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/examples/vqgan/train_vqgan.py b/examples/vqgan/train_vqgan.py
index a14ca13495..5ba1678d44 100644
--- a/examples/vqgan/train_vqgan.py
+++ b/examples/vqgan/train_vqgan.py
@@ -50,7 +50,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.34.0.dev0")
+check_min_version("0.35.0.dev0")
logger = get_logger(__name__, log_level="INFO")
diff --git a/scripts/convert_cosmos_to_diffusers.py b/scripts/convert_cosmos_to_diffusers.py
index 0c0426a1ef..6f6563ad64 100644
--- a/scripts/convert_cosmos_to_diffusers.py
+++ b/scripts/convert_cosmos_to_diffusers.py
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
"mlp.layer1": "ff.net.0.proj",
"mlp.layer2": "ff.net.2",
"x_embedder.proj.1": "patch_embed.proj",
- # "extra_pos_embedder": "learnable_pos_embed",
"final_layer.adaln_modulation.1": "norm_out.linear_1",
"final_layer.adaln_modulation.2": "norm_out.linear_2",
"final_layer.linear": "proj_out",
diff --git a/scripts/convert_skyreelsv2_to_diffusers.py b/scripts/convert_skyreelsv2_to_diffusers.py
new file mode 100644
index 0000000000..3bc3c43568
--- /dev/null
+++ b/scripts/convert_skyreelsv2_to_diffusers.py
@@ -0,0 +1,637 @@
+import argparse
+import os
+import pathlib
+from typing import Any, Dict
+
+import torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file
+from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+
+
+TRANSFORMER_KEYS_RENAME_DICT = {
+ "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+ "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+ "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+ "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+ "time_projection.1": "condition_embedder.time_proj",
+ "head.modulation": "scale_shift_table",
+ "head.head": "proj_out",
+ "modulation": "scale_shift_table",
+ "ffn.0": "ffn.net.0.proj",
+ "ffn.2": "ffn.net.2",
+ "fps_projection.0": "fps_projection.net.0.proj",
+ "fps_projection.2": "fps_projection.net.2",
+ # Hack to swap the layer names
+ # The original model calls the norms in following order: norm1, norm3, norm2
+ # We convert it to: norm1, norm2, norm3
+ "norm2": "norm__placeholder",
+ "norm3": "norm2",
+ "norm__placeholder": "norm3",
+ # For the I2V model
+ "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+ "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+ "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+ "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # for the FLF2V model
+ "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+ # Add attention component mappings
+ "self_attn.q": "attn1.to_q",
+ "self_attn.k": "attn1.to_k",
+ "self_attn.v": "attn1.to_v",
+ "self_attn.o": "attn1.to_out.0",
+ "self_attn.norm_q": "attn1.norm_q",
+ "self_attn.norm_k": "attn1.norm_k",
+ "cross_attn.q": "attn2.to_q",
+ "cross_attn.k": "attn2.to_k",
+ "cross_attn.v": "attn2.to_v",
+ "cross_attn.o": "attn2.to_out.0",
+ "cross_attn.norm_q": "attn2.norm_q",
+ "cross_attn.norm_k": "attn2.norm_k",
+ "attn2.to_k_img": "attn2.add_k_proj",
+ "attn2.to_v_img": "attn2.add_v_proj",
+ "attn2.norm_k_img": "attn2.norm_added_k",
+}
+
+TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+
+
+def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
+ state_dict[new_key] = state_dict.pop(old_key)
+
+
+def load_sharded_safetensors(dir: pathlib.Path):
+ if "720P" in str(dir):
+ file_paths = list(dir.glob("diffusion_pytorch_model*.safetensors"))
+ else:
+ file_paths = list(dir.glob("model*.safetensors"))
+ state_dict = {}
+ for path in file_paths:
+ state_dict.update(load_file(path))
+ return state_dict
+
+
+def get_transformer_config(model_type: str) -> Dict[str, Any]:
+ if model_type == "SkyReels-V2-DF-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 12,
+ "inject_sample_info": True,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-DF-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-DF-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-DF-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-T2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-T2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-T2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-T2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 1536,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 12,
+ "inject_sample_info": False,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-I2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-1.3B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-1.3B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 1536,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 8960,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 12,
+ "inject_sample_info": False,
+ "num_layers": 30,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-14B-540P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-540P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ elif model_type == "SkyReels-V2-FLF2V-14B-720P":
+ config = {
+ "model_id": "Skywork/SkyReels-V2-I2V-14B-720P",
+ "diffusers_config": {
+ "added_kv_proj_dim": 5120,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "inject_sample_info": False,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ "image_dim": 1280,
+ "pos_embed_seq_len": 514,
+ },
+ }
+ return config
+
+
+def convert_transformer(model_type: str):
+ config = get_transformer_config(model_type)
+ diffusers_config = config["diffusers_config"]
+ model_id = config["model_id"]
+
+ if "1.3B" in model_type:
+ original_state_dict = load_file(hf_hub_download(model_id, "model.safetensors"))
+ else:
+ os.makedirs(model_type, exist_ok=True)
+ model_dir = pathlib.Path(model_type)
+ if "720P" in model_type:
+ top_shard = 7 if "I2V" in model_type else 6
+ zeros = "0" * (4 if "I2V" or "T2V" in model_type else 3)
+ model_name = "diffusion_pytorch_model"
+ elif "540P" in model_type:
+ top_shard = 14 if "I2V" in model_type else 12
+ model_name = "model"
+
+ for i in range(1, top_shard + 1):
+ shard_path = f"{model_name}-{i:05d}-of-{zeros}{top_shard}.safetensors"
+ hf_hub_download(model_id, shard_path, local_dir=model_dir)
+ original_state_dict = load_sharded_safetensors(model_dir)
+
+ with init_empty_weights():
+ transformer = SkyReelsV2Transformer3DModel.from_config(diffusers_config)
+
+ for key in list(original_state_dict.keys()):
+ new_key = key[:]
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ update_state_dict_(original_state_dict, key, new_key)
+
+ for key in list(original_state_dict.keys()):
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, original_state_dict)
+
+ if "FLF2V" in model_type:
+ if (
+ hasattr(transformer.condition_embedder, "image_embedder")
+ and hasattr(transformer.condition_embedder.image_embedder, "pos_embed")
+ and transformer.condition_embedder.image_embedder.pos_embed is not None
+ ):
+ pos_embed_shape = transformer.condition_embedder.image_embedder.pos_embed.shape
+ original_state_dict["condition_embedder.image_embedder.pos_embed"] = torch.zeros(pos_embed_shape)
+
+ transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+ return transformer
+
+
+def convert_vae():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.1-T2V-14B", "Wan2.1_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Convert to down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Convert residual block naming but keep the original structure
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Convert to up_blocks
+ parts = key.split(".")
+ block_idx = int(parts[2])
+
+ # Group residual blocks
+ if "residual" in key:
+ if block_idx in [0, 1, 2]:
+ new_block_idx = 0
+ resnet_idx = block_idx
+ elif block_idx in [4, 5, 6]:
+ new_block_idx = 1
+ resnet_idx = block_idx - 4
+ elif block_idx in [8, 9, 10]:
+ new_block_idx = 2
+ resnet_idx = block_idx - 8
+ elif block_idx in [12, 13, 14]:
+ new_block_idx = 3
+ resnet_idx = block_idx - 12
+ else:
+ # Keep as is for other blocks
+ new_state_dict[key] = value
+ continue
+
+ # Convert residual block naming
+ if ".residual.0.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm1.gamma"
+ elif ".residual.2.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.bias"
+ elif ".residual.2.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv1.weight"
+ elif ".residual.3.gamma" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.norm2.gamma"
+ elif ".residual.6.bias" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.bias"
+ elif ".residual.6.weight" in key:
+ new_key = f"decoder.up_blocks.{new_block_idx}.resnets.{resnet_idx}.conv2.weight"
+ else:
+ new_key = key
+
+ new_state_dict[new_key] = value
+
+ # Handle shortcut connections
+ elif ".shortcut." in key:
+ if block_idx == 4:
+ new_key = key.replace(".shortcut.", ".resnets.0.conv_shortcut.")
+ new_key = new_key.replace("decoder.upsamples.4", "decoder.up_blocks.1")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_key = new_key.replace(".shortcut.", ".conv_shortcut.")
+
+ new_state_dict[new_key] = value
+
+ # Handle upsamplers
+ elif ".resample." in key or ".time_conv." in key:
+ if block_idx == 3:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.0.upsamplers.0")
+ elif block_idx == 7:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.1.upsamplers.0")
+ elif block_idx == 11:
+ new_key = key.replace(f"decoder.upsamples.{block_idx}", "decoder.up_blocks.2.upsamplers.0")
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ new_state_dict[new_key] = value
+ else:
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan()
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model_type", type=str, default=None)
+ parser.add_argument("--output_path", type=str, required=True)
+ parser.add_argument("--dtype", default="fp32")
+ return parser.parse_args()
+
+
+DTYPE_MAPPING = {
+ "fp32": torch.float32,
+ "fp16": torch.float16,
+ "bf16": torch.bfloat16,
+}
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ transformer = None
+ dtype = DTYPE_MAPPING[args.dtype]
+
+ transformer = convert_transformer(args.model_type).to(dtype=dtype)
+ vae = convert_vae()
+ text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+ tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
+ scheduler = UniPCMultistepScheduler(
+ prediction_type="flow_prediction",
+ num_train_timesteps=1000,
+ use_flow_sigmas=True,
+ )
+
+ if "I2V" in args.model_type or "FLF2V" in args.model_type:
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ image_processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+ pipe = SkyReelsV2ImageToVideoPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ image_processor=image_processor,
+ )
+ elif "T2V" in args.model_type:
+ pipe = SkyReelsV2Pipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+ elif "DF" in args.model_type:
+ pipe = SkyReelsV2DiffusionForcingPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ )
+
+ pipe.save_pretrained(
+ args.output_path,
+ safe_serialization=True,
+ max_shard_size="5GB",
+ # push_to_hub=True,
+ # repo_id=f"/{args.model_type}-Diffusers",
+ )
diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index 6d25cde071..599c90be57 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -278,16 +278,82 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-I2V-14B-720p":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-I2V-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 36,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-T2V-A14B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-T2V-A14B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 13824,
+ "freq_dim": 256,
+ "in_channels": 16,
+ "num_attention_heads": 40,
+ "num_layers": 40,
+ "out_channels": 16,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+ elif model_type == "Wan2.2-TI2V-5B":
+ config = {
+ "model_id": "Wan-AI/Wan2.2-TI2V-5B",
+ "diffusers_config": {
+ "added_kv_proj_dim": None,
+ "attention_head_dim": 128,
+ "cross_attn_norm": True,
+ "eps": 1e-06,
+ "ffn_dim": 14336,
+ "freq_dim": 256,
+ "in_channels": 48,
+ "num_attention_heads": 24,
+ "num_layers": 30,
+ "out_channels": 48,
+ "patch_size": [1, 2, 2],
+ "qk_norm": "rms_norm_across_heads",
+ "text_dim": 4096,
+ },
+ }
+ RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+ SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
return config, RENAME_DICT, SPECIAL_KEYS_REMAP
-def convert_transformer(model_type: str):
+def convert_transformer(model_type: str, stage: str = None):
config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
diffusers_config = config["diffusers_config"]
model_id = config["model_id"]
model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
+ if stage is not None:
+ model_dir = model_dir / stage
+
original_state_dict = load_sharded_safetensors(model_dir)
with init_empty_weights():
@@ -515,6 +581,310 @@ def convert_vae():
return vae
+vae22_diffusers_config = {
+ "base_dim": 160,
+ "z_dim": 48,
+ "is_residual": True,
+ "in_channels": 12,
+ "out_channels": 12,
+ "decoder_base_dim": 256,
+ "scale_factor_temporal": 4,
+ "scale_factor_spatial": 16,
+ "patch_size": 2,
+ "latents_mean": [
+ -0.2289,
+ -0.0052,
+ -0.1323,
+ -0.2339,
+ -0.2799,
+ 0.0174,
+ 0.1838,
+ 0.1557,
+ -0.1382,
+ 0.0542,
+ 0.2813,
+ 0.0891,
+ 0.1570,
+ -0.0098,
+ 0.0375,
+ -0.1825,
+ -0.2246,
+ -0.1207,
+ -0.0698,
+ 0.5109,
+ 0.2665,
+ -0.2108,
+ -0.2158,
+ 0.2502,
+ -0.2055,
+ -0.0322,
+ 0.1109,
+ 0.1567,
+ -0.0729,
+ 0.0899,
+ -0.2799,
+ -0.1230,
+ -0.0313,
+ -0.1649,
+ 0.0117,
+ 0.0723,
+ -0.2839,
+ -0.2083,
+ -0.0520,
+ 0.3748,
+ 0.0152,
+ 0.1957,
+ 0.1433,
+ -0.2944,
+ 0.3573,
+ -0.0548,
+ -0.1681,
+ -0.0667,
+ ],
+ "latents_std": [
+ 0.4765,
+ 1.0364,
+ 0.4514,
+ 1.1677,
+ 0.5313,
+ 0.4990,
+ 0.4818,
+ 0.5013,
+ 0.8158,
+ 1.0344,
+ 0.5894,
+ 1.0901,
+ 0.6885,
+ 0.6165,
+ 0.8454,
+ 0.4978,
+ 0.5759,
+ 0.3523,
+ 0.7135,
+ 0.6804,
+ 0.5833,
+ 1.4146,
+ 0.8986,
+ 0.5659,
+ 0.7069,
+ 0.5338,
+ 0.4889,
+ 0.4917,
+ 0.4069,
+ 0.4999,
+ 0.6866,
+ 0.4093,
+ 0.5709,
+ 0.6065,
+ 0.6415,
+ 0.4944,
+ 0.5726,
+ 1.2042,
+ 0.5458,
+ 1.6887,
+ 0.3971,
+ 1.0600,
+ 0.3943,
+ 0.5537,
+ 0.5444,
+ 0.4089,
+ 0.7468,
+ 0.7744,
+ ],
+ "clip_output": False,
+}
+
+
+def convert_vae_22():
+ vae_ckpt_path = hf_hub_download("Wan-AI/Wan2.2-TI2V-5B", "Wan2.2_VAE.pth")
+ old_state_dict = torch.load(vae_ckpt_path, weights_only=True)
+ new_state_dict = {}
+
+ # Create mappings for specific components
+ middle_key_mapping = {
+ # Encoder middle block
+ "encoder.middle.0.residual.0.gamma": "encoder.mid_block.resnets.0.norm1.gamma",
+ "encoder.middle.0.residual.2.bias": "encoder.mid_block.resnets.0.conv1.bias",
+ "encoder.middle.0.residual.2.weight": "encoder.mid_block.resnets.0.conv1.weight",
+ "encoder.middle.0.residual.3.gamma": "encoder.mid_block.resnets.0.norm2.gamma",
+ "encoder.middle.0.residual.6.bias": "encoder.mid_block.resnets.0.conv2.bias",
+ "encoder.middle.0.residual.6.weight": "encoder.mid_block.resnets.0.conv2.weight",
+ "encoder.middle.2.residual.0.gamma": "encoder.mid_block.resnets.1.norm1.gamma",
+ "encoder.middle.2.residual.2.bias": "encoder.mid_block.resnets.1.conv1.bias",
+ "encoder.middle.2.residual.2.weight": "encoder.mid_block.resnets.1.conv1.weight",
+ "encoder.middle.2.residual.3.gamma": "encoder.mid_block.resnets.1.norm2.gamma",
+ "encoder.middle.2.residual.6.bias": "encoder.mid_block.resnets.1.conv2.bias",
+ "encoder.middle.2.residual.6.weight": "encoder.mid_block.resnets.1.conv2.weight",
+ # Decoder middle block
+ "decoder.middle.0.residual.0.gamma": "decoder.mid_block.resnets.0.norm1.gamma",
+ "decoder.middle.0.residual.2.bias": "decoder.mid_block.resnets.0.conv1.bias",
+ "decoder.middle.0.residual.2.weight": "decoder.mid_block.resnets.0.conv1.weight",
+ "decoder.middle.0.residual.3.gamma": "decoder.mid_block.resnets.0.norm2.gamma",
+ "decoder.middle.0.residual.6.bias": "decoder.mid_block.resnets.0.conv2.bias",
+ "decoder.middle.0.residual.6.weight": "decoder.mid_block.resnets.0.conv2.weight",
+ "decoder.middle.2.residual.0.gamma": "decoder.mid_block.resnets.1.norm1.gamma",
+ "decoder.middle.2.residual.2.bias": "decoder.mid_block.resnets.1.conv1.bias",
+ "decoder.middle.2.residual.2.weight": "decoder.mid_block.resnets.1.conv1.weight",
+ "decoder.middle.2.residual.3.gamma": "decoder.mid_block.resnets.1.norm2.gamma",
+ "decoder.middle.2.residual.6.bias": "decoder.mid_block.resnets.1.conv2.bias",
+ "decoder.middle.2.residual.6.weight": "decoder.mid_block.resnets.1.conv2.weight",
+ }
+
+ # Create a mapping for attention blocks
+ attention_mapping = {
+ # Encoder middle attention
+ "encoder.middle.1.norm.gamma": "encoder.mid_block.attentions.0.norm.gamma",
+ "encoder.middle.1.to_qkv.weight": "encoder.mid_block.attentions.0.to_qkv.weight",
+ "encoder.middle.1.to_qkv.bias": "encoder.mid_block.attentions.0.to_qkv.bias",
+ "encoder.middle.1.proj.weight": "encoder.mid_block.attentions.0.proj.weight",
+ "encoder.middle.1.proj.bias": "encoder.mid_block.attentions.0.proj.bias",
+ # Decoder middle attention
+ "decoder.middle.1.norm.gamma": "decoder.mid_block.attentions.0.norm.gamma",
+ "decoder.middle.1.to_qkv.weight": "decoder.mid_block.attentions.0.to_qkv.weight",
+ "decoder.middle.1.to_qkv.bias": "decoder.mid_block.attentions.0.to_qkv.bias",
+ "decoder.middle.1.proj.weight": "decoder.mid_block.attentions.0.proj.weight",
+ "decoder.middle.1.proj.bias": "decoder.mid_block.attentions.0.proj.bias",
+ }
+
+ # Create a mapping for the head components
+ head_mapping = {
+ # Encoder head
+ "encoder.head.0.gamma": "encoder.norm_out.gamma",
+ "encoder.head.2.bias": "encoder.conv_out.bias",
+ "encoder.head.2.weight": "encoder.conv_out.weight",
+ # Decoder head
+ "decoder.head.0.gamma": "decoder.norm_out.gamma",
+ "decoder.head.2.bias": "decoder.conv_out.bias",
+ "decoder.head.2.weight": "decoder.conv_out.weight",
+ }
+
+ # Create a mapping for the quant components
+ quant_mapping = {
+ "conv1.weight": "quant_conv.weight",
+ "conv1.bias": "quant_conv.bias",
+ "conv2.weight": "post_quant_conv.weight",
+ "conv2.bias": "post_quant_conv.bias",
+ }
+
+ # Process each key in the state dict
+ for key, value in old_state_dict.items():
+ # Handle middle block keys using the mapping
+ if key in middle_key_mapping:
+ new_key = middle_key_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle attention blocks using the mapping
+ elif key in attention_mapping:
+ new_key = attention_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle head keys using the mapping
+ elif key in head_mapping:
+ new_key = head_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle quant keys using the mapping
+ elif key in quant_mapping:
+ new_key = quant_mapping[key]
+ new_state_dict[new_key] = value
+ # Handle encoder conv1
+ elif key == "encoder.conv1.weight":
+ new_state_dict["encoder.conv_in.weight"] = value
+ elif key == "encoder.conv1.bias":
+ new_state_dict["encoder.conv_in.bias"] = value
+ # Handle decoder conv1
+ elif key == "decoder.conv1.weight":
+ new_state_dict["decoder.conv_in.weight"] = value
+ elif key == "decoder.conv1.bias":
+ new_state_dict["decoder.conv_in.bias"] = value
+ # Handle encoder downsamples
+ elif key.startswith("encoder.downsamples."):
+ # Change encoder.downsamples to encoder.down_blocks
+ new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
+
+ # Handle residual blocks - change downsamples to resnets and rename components
+ if "residual" in new_key or "shortcut" in new_key:
+ # Change the second downsamples to resnets
+ new_key = new_key.replace(".downsamples.", ".resnets.")
+
+ # Rename residual components
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+ # Handle resample blocks - change downsamples to downsampler and remove index
+ elif "resample" in new_key or "time_conv" in new_key:
+ # Change the second downsamples to downsampler and remove the index
+ parts = new_key.split(".")
+ # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
+ # We want to change it to: encoder.down_blocks.X.downsampler.resample...
+ if len(parts) >= 4 and parts[3] == "downsamples":
+ # Remove the index (parts[4]) and change downsamples to downsampler
+ new_parts = parts[:3] + ["downsampler"] + parts[5:]
+ new_key = ".".join(new_parts)
+
+ new_state_dict[new_key] = value
+
+ # Handle decoder upsamples
+ elif key.startswith("decoder.upsamples."):
+ # Change decoder.upsamples to decoder.up_blocks
+ new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
+
+ # Handle residual blocks - change upsamples to resnets and rename components
+ if "residual" in new_key or "shortcut" in new_key:
+ # Change the second upsamples to resnets
+ new_key = new_key.replace(".upsamples.", ".resnets.")
+
+ # Rename residual components
+ if ".residual.0.gamma" in new_key:
+ new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
+ elif ".residual.2.weight" in new_key:
+ new_key = new_key.replace(".residual.2.weight", ".conv1.weight")
+ elif ".residual.2.bias" in new_key:
+ new_key = new_key.replace(".residual.2.bias", ".conv1.bias")
+ elif ".residual.3.gamma" in new_key:
+ new_key = new_key.replace(".residual.3.gamma", ".norm2.gamma")
+ elif ".residual.6.weight" in new_key:
+ new_key = new_key.replace(".residual.6.weight", ".conv2.weight")
+ elif ".residual.6.bias" in new_key:
+ new_key = new_key.replace(".residual.6.bias", ".conv2.bias")
+ elif ".shortcut.weight" in new_key:
+ new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
+ elif ".shortcut.bias" in new_key:
+ new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
+
+ # Handle resample blocks - change upsamples to upsampler and remove index
+ elif "resample" in new_key or "time_conv" in new_key:
+ # Change the second upsamples to upsampler and remove the index
+ parts = new_key.split(".")
+ # Find the pattern: encoder.down_blocks.X.downsamples.Y.resample...
+ # We want to change it to: encoder.down_blocks.X.downsampler.resample...
+ if len(parts) >= 4 and parts[3] == "upsamples":
+ # Remove the index (parts[4]) and change upsamples to upsampler
+ new_parts = parts[:3] + ["upsampler"] + parts[5:]
+ new_key = ".".join(new_parts)
+
+ new_state_dict[new_key] = value
+ else:
+ # Keep other keys unchanged
+ new_state_dict[key] = value
+
+ with init_empty_weights():
+ vae = AutoencoderKLWan(**vae22_diffusers_config)
+ vae.load_state_dict(new_state_dict, strict=True, assign=True)
+ return vae
+
+
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default=None)
@@ -533,11 +903,26 @@ DTYPE_MAPPING = {
if __name__ == "__main__":
args = get_args()
- transformer = convert_transformer(args.model_type)
- vae = convert_vae()
+ if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
+ transformer = convert_transformer(args.model_type, stage="high_noise_model")
+ transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
+ else:
+ transformer = convert_transformer(args.model_type)
+ transformer_2 = None
+
+ if "Wan2.2" in args.model_type and "TI2V" in args.model_type:
+ vae = convert_vae_22()
+ else:
+ vae = convert_vae()
+
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
- flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
+ if "FLF2V" in args.model_type:
+ flow_shift = 16.0
+ elif "TI2V" in args.model_type:
+ flow_shift = 5.0
+ else:
+ flow_shift = 3.0
scheduler = UniPCMultistepScheduler(
prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)
@@ -547,7 +932,36 @@ if __name__ == "__main__":
dtype = DTYPE_MAPPING[args.dtype]
transformer.to(dtype)
- if "I2V" in args.model_type or "FLF2V" in args.model_type:
+ if "Wan2.2" and "I2V" in args.model_type and "TI2V" not in args.model_type:
+ pipe = WanImageToVideoPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.9,
+ )
+ elif "Wan2.2" and "T2V" in args.model_type:
+ pipe = WanPipeline(
+ transformer=transformer,
+ transformer_2=transformer_2,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ boundary_ratio=0.875,
+ )
+ elif "Wan2.2" and "TI2V" in args.model_type:
+ pipe = WanPipeline(
+ transformer=transformer,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ vae=vae,
+ scheduler=scheduler,
+ expand_timesteps=True,
+ )
+ elif "I2V" in args.model_type or "FLF2V" in args.model_type:
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
"laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
)
diff --git a/setup.py b/setup.py
index e8df544e0c..799150fd03 100644
--- a/setup.py
+++ b/setup.py
@@ -102,7 +102,7 @@ _deps = [
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
- "huggingface-hub>=0.27.0",
+ "huggingface-hub>=0.34.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
@@ -110,7 +110,7 @@ _deps = [
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
- "k-diffusion>=0.0.12",
+ "k-diffusion==0.0.12",
"torchsde",
"note_seq",
"librosa",
@@ -269,7 +269,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
- version="0.34.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+ version="0.35.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 81051b9f25..1414d0fc69 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.34.0.dev0"
+__version__ = "0.35.0.dev0"
from typing import TYPE_CHECKING
@@ -34,10 +34,13 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
+ "guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
+ "modular_pipelines": [],
"pipelines": [],
+ "quantizers.pipe_quant_config": ["PipelineQuantizationConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -130,12 +133,29 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
+ _import_structure["guiders"].extend(
+ [
+ "AdaptiveProjectedGuidance",
+ "AutoGuidance",
+ "ClassifierFreeGuidance",
+ "ClassifierFreeZeroStarGuidance",
+ "PerturbedAttentionGuidance",
+ "SkipLayerGuidance",
+ "SmoothedEnergyGuidance",
+ "TangentialClassifierFreeGuidance",
+ ]
+ )
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
+ "FirstBlockCacheConfig",
"HookRegistry",
+ "LayerSkipConfig",
"PyramidAttentionBroadcastConfig",
+ "SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
+ "apply_first_block_cache",
+ "apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -143,6 +163,7 @@ else:
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
+ "AttentionBackendName",
"AuraFlowTransformer2DModel",
"AutoencoderDC",
"AutoencoderKL",
@@ -199,6 +220,7 @@ else:
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
+ "SkyReelsV2Transformer3DModel",
"SparseControlNetModel",
"StableAudioDiTModel",
"StableCascadeUNet",
@@ -217,6 +239,15 @@ else:
"VQModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
+ "attention_backend",
+ ]
+ )
+ _import_structure["modular_pipelines"].extend(
+ [
+ "ComponentsManager",
+ "ComponentSpec",
+ "ModularPipeline",
+ "ModularPipelineBlocks",
]
)
_import_structure["optimization"] = [
@@ -331,6 +362,16 @@ except OptionalDependencyNotAvailable:
]
else:
+ _import_structure["modular_pipelines"].extend(
+ [
+ "FluxAutoBlocks",
+ "FluxModularPipeline",
+ "StableDiffusionXLAutoBlocks",
+ "StableDiffusionXLModularPipeline",
+ "WanAutoBlocks",
+ "WanModularPipeline",
+ ]
+ )
_import_structure["pipelines"].extend(
[
"AllegroPipeline",
@@ -381,6 +422,8 @@ else:
"FluxFillPipeline",
"FluxImg2ImgPipeline",
"FluxInpaintPipeline",
+ "FluxKontextInpaintPipeline",
+ "FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"HiDreamImagePipeline",
@@ -452,6 +495,11 @@ else:
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline",
+ "SkyReelsV2DiffusionForcingPipeline",
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline",
+ "SkyReelsV2ImageToVideoPipeline",
+ "SkyReelsV2Pipeline",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
@@ -541,6 +589,7 @@ else:
]
)
+
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -747,16 +796,32 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
+ from .guiders import (
+ AdaptiveProjectedGuidance,
+ AutoGuidance,
+ ClassifierFreeGuidance,
+ ClassifierFreeZeroStarGuidance,
+ PerturbedAttentionGuidance,
+ SkipLayerGuidance,
+ SmoothedEnergyGuidance,
+ TangentialClassifierFreeGuidance,
+ )
from .hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
HookRegistry,
+ LayerSkipConfig,
PyramidAttentionBroadcastConfig,
+ SmoothedEnergyGuidanceConfig,
apply_faster_cache,
+ apply_first_block_cache,
+ apply_layer_skip,
apply_pyramid_attention_broadcast,
)
from .models import (
AllegroTransformer3DModel,
AsymmetricAutoencoderKL,
+ AttentionBackendName,
AuraFlowTransformer2DModel,
AutoencoderDC,
AutoencoderKL,
@@ -813,6 +878,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3ControlNetModel,
SD3MultiControlNetModel,
SD3Transformer2DModel,
+ SkyReelsV2Transformer3DModel,
SparseControlNetModel,
StableAudioDiTModel,
T2IAdapter,
@@ -830,6 +896,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
VQModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
+ attention_backend,
+ )
+ from .modular_pipelines import (
+ ComponentsManager,
+ ComponentSpec,
+ ModularPipeline,
+ ModularPipelineBlocks,
)
from .optimization import (
get_constant_schedule,
@@ -927,6 +1000,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
+ from .modular_pipelines import (
+ FluxAutoBlocks,
+ FluxModularPipeline,
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLModularPipeline,
+ WanAutoBlocks,
+ WanModularPipeline,
+ )
from .pipelines import (
AllegroPipeline,
AltDiffusionImg2ImgPipeline,
@@ -974,6 +1055,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
HiDreamImagePipeline,
@@ -1045,6 +1128,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 4b8b15368c..2a08f091d9 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -207,3 +207,38 @@ class IPAdapterScaleCutoffCallback(PipelineCallback):
if step_index == cutoff_step:
pipeline.set_ip_adapter_scale(0.0)
return callback_kwargs
+
+
+class SD3CFGCutoffCallback(PipelineCallback):
+ """
+ Callback function for Stable Diffusion 3 Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
+ `cutoff_step_index`), this callback will disable the CFG.
+
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+ """
+
+ tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
+
+ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+ cutoff_step_ratio = self.config.cutoff_step_ratio
+ cutoff_step_index = self.config.cutoff_step_index
+
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+ cutoff_step = (
+ cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+ )
+
+ if step_index == cutoff_step:
+ prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+ prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.
+
+ pooled_prompt_embeds = callback_kwargs[self.tensor_inputs[1]]
+ pooled_prompt_embeds = pooled_prompt_embeds[
+ -1:
+ ] # "-1" denotes the embeddings for conditional pooled text tokens.
+
+ pipeline._guidance_scale = 0.0
+
+ callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+ callback_kwargs[self.tensor_inputs[1]] = pooled_prompt_embeds
+ return callback_kwargs
diff --git a/src/diffusers/commands/custom_blocks.py b/src/diffusers/commands/custom_blocks.py
new file mode 100644
index 0000000000..43d9ea8857
--- /dev/null
+++ b/src/diffusers/commands/custom_blocks.py
@@ -0,0 +1,134 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage example:
+ TODO
+"""
+
+import ast
+import importlib.util
+import os
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+
+from ..utils import logging
+from . import BaseDiffusersCLICommand
+
+
+EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
+CONFIG = "config.json"
+
+
+def conversion_command_factory(args: Namespace):
+ return CustomBlocksCommand(args.block_module_name, args.block_class_name)
+
+
+class CustomBlocksCommand(BaseDiffusersCLICommand):
+ @staticmethod
+ def register_subcommand(parser: ArgumentParser):
+ conversion_parser = parser.add_parser("custom_blocks")
+ conversion_parser.add_argument(
+ "--block_module_name",
+ type=str,
+ default="block.py",
+ help="Module filename in which the custom block will be implemented.",
+ )
+ conversion_parser.add_argument(
+ "--block_class_name",
+ type=str,
+ default=None,
+ help="Name of the custom block. If provided None, we will try to infer it.",
+ )
+ conversion_parser.set_defaults(func=conversion_command_factory)
+
+ def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
+ self.logger = logging.get_logger("diffusers-cli/custom_blocks")
+ self.block_module_name = Path(block_module_name)
+ self.block_class_name = block_class_name
+
+ def run(self):
+ # determine the block to be saved.
+ out = self._get_class_names(self.block_module_name)
+ classes_found = list({cls for cls, _ in out})
+
+ if self.block_class_name is not None:
+ child_class, parent_class = self._choose_block(out, self.block_class_name)
+ if child_class is None and parent_class is None:
+ raise ValueError(
+ "`block_class_name` could not be retrieved. Available classes from "
+ f"{self.block_module_name}:\n{classes_found}"
+ )
+ else:
+ self.logger.info(
+ f"Found classes: {classes_found} will be using {classes_found[0]}. "
+ "If this needs to be changed, re-run the command specifying `block_class_name`."
+ )
+ child_class, parent_class = out[0][0], out[0][1]
+
+ # dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
+ # the user is responsible for running it, so I guess that is safe?
+ module_name = f"__dynamic__{self.block_module_name.stem}"
+ spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ getattr(module, child_class)().save_pretrained(os.getcwd())
+
+ # or, we could create it manually.
+ # automap = self._create_automap(parent_class=parent_class, child_class=child_class)
+ # with open(CONFIG, "w") as f:
+ # json.dump(automap, f)
+ with open("requirements.txt", "w") as f:
+ f.write("")
+
+ def _choose_block(self, candidates, chosen=None):
+ for cls, base in candidates:
+ if cls == chosen:
+ return cls, base
+ return None, None
+
+ def _get_class_names(self, file_path):
+ source = file_path.read_text(encoding="utf-8")
+ try:
+ tree = ast.parse(source, filename=file_path)
+ except SyntaxError as e:
+ raise ValueError(f"Could not parse {file_path!r}: {e}") from e
+
+ results: list[tuple[str, str]] = []
+ for node in tree.body:
+ if not isinstance(node, ast.ClassDef):
+ continue
+
+ # extract all base names for this class
+ base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
+
+ # for each allowed base that appears in the class's bases, emit a tuple
+ for allowed in EXPECTED_PARENT_CLASSES:
+ if allowed in base_names:
+ results.append((node.name, allowed))
+
+ return results
+
+ def _get_base_name(self, node: ast.expr):
+ if isinstance(node, ast.Name):
+ return node.id
+ elif isinstance(node, ast.Attribute):
+ val = self._get_base_name(node.value)
+ return f"{val}.{node.attr}" if val else node.attr
+ return None
+
+ def _create_automap(self, parent_class, child_class):
+ module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
+ auto_map = {f"{parent_class}": f"{module}.{child_class}"}
+ return {"auto_map": auto_map}
diff --git a/src/diffusers/commands/diffusers_cli.py b/src/diffusers/commands/diffusers_cli.py
index 3c744c5c4c..a27ac24f2a 100644
--- a/src/diffusers/commands/diffusers_cli.py
+++ b/src/diffusers/commands/diffusers_cli.py
@@ -15,6 +15,7 @@
from argparse import ArgumentParser
+from .custom_blocks import CustomBlocksCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -26,6 +27,7 @@ def main():
# Register commands
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
+ CustomBlocksCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()
diff --git a/src/diffusers/commands/fp16_safetensors.py b/src/diffusers/commands/fp16_safetensors.py
index ef60f237ae..41739261e5 100644
--- a/src/diffusers/commands/fp16_safetensors.py
+++ b/src/diffusers/commands/fp16_safetensors.py
@@ -59,7 +59,7 @@ class FP16SafetensorsCommand(BaseDiffusersCLICommand):
conversion_parser.add_argument(
"--use_auth_token",
action="store_true",
- help="When working with checkpoints having private visibility. When used `huggingface-cli login` needs to be run beforehand.",
+ help="When working with checkpoints having private visibility. When used `hf auth login` needs to be run beforehand.",
)
conversion_parser.set_defaults(func=conversion_command_factory)
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index f9b652bbc0..540aab0307 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -176,6 +176,7 @@ class ConfigMixin:
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+ subfolder = kwargs.pop("subfolder", None)
self._upload_folder(
save_directory,
@@ -183,6 +184,7 @@ class ConfigMixin:
token=token,
commit_message=commit_message,
create_pr=create_pr,
+ subfolder=subfolder,
)
@classmethod
@@ -405,7 +407,7 @@ class ConfigMixin:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
" listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
- " token having permission to this repo with `token` or log in with `huggingface-cli login`."
+ " token having permission to this repo with `token` or log in with `hf auth login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
@@ -601,6 +603,10 @@ class ConfigMixin:
value = value.tolist()
elif isinstance(value, Path):
value = value.as_posix()
+ elif hasattr(value, "to_dict") and callable(value.to_dict):
+ value = value.to_dict()
+ elif isinstance(value, list):
+ value = [to_json_saveable(v) for v in value]
return value
if "quantization_config" in config_dict:
@@ -757,4 +763,7 @@ class LegacyConfigMixin(ConfigMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
- return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
+ if remapped_class is cls:
+ return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
+ else:
+ return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py
index f35353c49e..3d14a8b3e0 100644
--- a/src/diffusers/dependency_versions_table.py
+++ b/src/diffusers/dependency_versions_table.py
@@ -9,7 +9,7 @@ deps = {
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
- "huggingface-hub": "huggingface-hub>=0.27.0",
+ "huggingface-hub": "huggingface-hub>=0.34.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
@@ -17,7 +17,7 @@ deps = {
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
- "k-diffusion": "k-diffusion>=0.0.12",
+ "k-diffusion": "k-diffusion==0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",
diff --git a/src/diffusers/guiders/__init__.py b/src/diffusers/guiders/__init__.py
new file mode 100644
index 0000000000..1c288f00f0
--- /dev/null
+++ b/src/diffusers/guiders/__init__.py
@@ -0,0 +1,39 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+from ..utils import is_torch_available
+
+
+if is_torch_available():
+ from .adaptive_projected_guidance import AdaptiveProjectedGuidance
+ from .auto_guidance import AutoGuidance
+ from .classifier_free_guidance import ClassifierFreeGuidance
+ from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
+ from .perturbed_attention_guidance import PerturbedAttentionGuidance
+ from .skip_layer_guidance import SkipLayerGuidance
+ from .smoothed_energy_guidance import SmoothedEnergyGuidance
+ from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
+
+ GuiderType = Union[
+ AdaptiveProjectedGuidance,
+ AutoGuidance,
+ ClassifierFreeGuidance,
+ ClassifierFreeZeroStarGuidance,
+ PerturbedAttentionGuidance,
+ SkipLayerGuidance,
+ SmoothedEnergyGuidance,
+ TangentialClassifierFreeGuidance,
+ ]
diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py
new file mode 100644
index 0000000000..81137db106
--- /dev/null
+++ b/src/diffusers/guiders/adaptive_projected_guidance.py
@@ -0,0 +1,188 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AdaptiveProjectedGuidance(BaseGuidance):
+ """
+ Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ adaptive_projected_guidance_momentum (`float`, defaults to `None`):
+ The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
+ adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ adaptive_projected_guidance_momentum: Optional[float] = None,
+ adaptive_projected_guidance_rescale: float = 15.0,
+ eta: float = 1.0,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+ self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+ self.eta = eta
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+ self.momentum_buffer = None
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ if self._step == 0:
+ if self.adaptive_projected_guidance_momentum is not None:
+ self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pred = None
+
+ if not self._is_apg_enabled():
+ pred = pred_cond
+ else:
+ pred = normalized_guidance(
+ pred_cond,
+ pred_uncond,
+ self.guidance_scale,
+ self.momentum_buffer,
+ self.eta,
+ self.adaptive_projected_guidance_rescale,
+ self.use_original_formulation,
+ )
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_apg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_apg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+class MomentumBuffer:
+ def __init__(self, momentum: float):
+ self.momentum = momentum
+ self.running_average = 0
+
+ def update(self, update_value: torch.Tensor):
+ new_average = self.momentum * self.running_average
+ self.running_average = update_value + new_average
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor,
+ pred_uncond: torch.Tensor,
+ guidance_scale: float,
+ momentum_buffer: Optional[MomentumBuffer] = None,
+ eta: float = 1.0,
+ norm_threshold: float = 0.0,
+ use_original_formulation: bool = False,
+):
+ diff = pred_cond - pred_uncond
+ dim = [-i for i in range(1, len(diff.shape))]
+
+ if momentum_buffer is not None:
+ momentum_buffer.update(diff)
+ diff = momentum_buffer.running_average
+
+ if norm_threshold > 0:
+ ones = torch.ones_like(diff)
+ diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+ scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+ diff = diff * scale_factor
+
+ v0, v1 = diff.double(), pred_cond.double()
+ v1 = torch.nn.functional.normalize(v1, dim=dim)
+ v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+ normalized_update = diff_orthogonal + eta * diff_parallel
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ pred = pred + guidance_scale * normalized_update
+
+ return pred
diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py
new file mode 100644
index 0000000000..e1642211d3
--- /dev/null
+++ b/src/diffusers/guiders/auto_guidance.py
@@ -0,0 +1,190 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class AutoGuidance(BaseGuidance):
+ """
+ AutoGuidance: https://huggingface.co/papers/2406.02507
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ auto_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
+ provided, `skip_layer_config` must be provided.
+ auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
+ dropout (`float`, *optional*):
+ The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
+ `auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ auto_guidance_layers: Optional[Union[int, List[int]]] = None,
+ auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ dropout: Optional[float] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.auto_guidance_layers = auto_guidance_layers
+ self.auto_guidance_config = auto_guidance_config
+ self.dropout = dropout
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if auto_guidance_layers is None and auto_guidance_config is None:
+ raise ValueError(
+ "Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
+ )
+ if auto_guidance_layers is not None and auto_guidance_config is not None:
+ raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
+ if (dropout is None and auto_guidance_layers is not None) or (
+ dropout is not None and auto_guidance_layers is None
+ ):
+ raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
+
+ if auto_guidance_layers is not None:
+ if isinstance(auto_guidance_layers, int):
+ auto_guidance_layers = [auto_guidance_layers]
+ if not isinstance(auto_guidance_layers, list):
+ raise ValueError(
+ f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
+ )
+ auto_guidance_config = [
+ LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
+ ]
+
+ if isinstance(auto_guidance_config, dict):
+ auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
+
+ if isinstance(auto_guidance_config, LayerSkipConfig):
+ auto_guidance_config = [auto_guidance_config]
+
+ if not isinstance(auto_guidance_config, list):
+ raise ValueError(
+ f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
+ )
+ elif isinstance(next(iter(auto_guidance_config), None), dict):
+ auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
+
+ self.auto_guidance_config = auto_guidance_config
+ self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_ag_enabled() and self.is_unconditional:
+ for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_ag_enabled() and self.is_unconditional:
+ for name in self._auto_guidance_hook_names:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ registry.remove_hook(name, recurse=True)
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pred = None
+
+ if not self._is_ag_enabled():
+ pred = pred_cond
+ else:
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_ag_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_ag_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py
new file mode 100644
index 0000000000..7e72b92fce
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_guidance.py
@@ -0,0 +1,141 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeGuidance(BaseGuidance):
+ """
+ Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
+
+ CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
+ jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
+ inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
+ proposes scaling and shifting the conditional distribution based on the difference between conditional and
+ unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
+
+ Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
+ paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
+ theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
+
+ The intution behind the original formulation can be thought of as moving the conditional distribution estimates
+ further away from the unconditional distribution estimates, while the diffusers-native implementation can be
+ thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
+ the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
+
+ The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
+ paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pred = None
+
+ if not self._is_cfg_enabled():
+ pred = pred_cond
+ else:
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
new file mode 100644
index 0000000000..85d5cc62d4
--- /dev/null
+++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py
@@ -0,0 +1,152 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class ClassifierFreeZeroStarGuidance(BaseGuidance):
+ """
+ Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
+
+ This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
+ guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
+ process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
+ quality of generated images.
+
+ The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ zero_init_steps (`int`, defaults to `1`):
+ The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ zero_init_steps: int = 1,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.zero_init_steps = zero_init_steps
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pred = None
+
+ if self._step < self.zero_init_steps:
+ pred = torch.zeros_like(pred_cond)
+ elif not self._is_cfg_enabled():
+ pred = pred_cond
+ else:
+ pred_cond_flat = pred_cond.flatten(1)
+ pred_uncond_flat = pred_uncond.flatten(1)
+ alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
+ alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
+ pred_uncond = pred_uncond * alpha
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
+ cond_dtype = cond.dtype
+ cond = cond.float()
+ uncond = uncond.float()
+ dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
+ squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ scale = dot_product / squared_norm
+ return scale.to(dtype=cond_dtype)
diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py
new file mode 100644
index 0000000000..9dc83a7f1d
--- /dev/null
+++ b/src/diffusers/guiders/guider_utils.py
@@ -0,0 +1,309 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
+
+from ..configuration_utils import ConfigMixin
+from ..utils import PushToHubMixin, get_logger
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+GUIDER_CONFIG_NAME = "guider_config.json"
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class BaseGuidance(ConfigMixin, PushToHubMixin):
+ r"""Base class providing the skeleton for implementing guidance techniques."""
+
+ config_name = GUIDER_CONFIG_NAME
+ _input_predictions = None
+ _identifier_key = "__guidance_identifier__"
+
+ def __init__(self, start: float = 0.0, stop: float = 1.0):
+ self._start = start
+ self._stop = stop
+ self._step: int = None
+ self._num_inference_steps: int = None
+ self._timestep: torch.LongTensor = None
+ self._count_prepared = 0
+ self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
+ self._enabled = True
+
+ if not (0.0 <= start < 1.0):
+ raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
+ if not (start <= stop <= 1.0):
+ raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
+
+ if self._input_predictions is None or not isinstance(self._input_predictions, list):
+ raise ValueError(
+ "`_input_predictions` must be a list of required prediction names for the guidance technique."
+ )
+
+ def disable(self):
+ self._enabled = False
+
+ def enable(self):
+ self._enabled = True
+
+ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
+ self._step = step
+ self._num_inference_steps = num_inference_steps
+ self._timestep = timestep
+ self._count_prepared = 0
+
+ def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
+ """
+ Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
+ attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
+ the values of the provided keyword arguments to this method.
+
+ Args:
+ **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
+ to look up the required data provided for preparation.
+
+ If a string is provided, it will be used as the conditional data (or unconditional if used with a
+ guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
+ conditional data identifier and the second element must be the unconditional data identifier or None.
+
+ Example:
+ ```
+ data = {"prompt_embeds": , "negative_prompt_embeds": , "latents": }
+
+ BaseGuidance.set_input_fields(
+ latents="latents",
+ prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
+ )
+ ```
+ """
+ for key, value in kwargs.items():
+ is_string = isinstance(value, str)
+ is_tuple_of_str_with_len_2 = (
+ isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
+ )
+ if not (is_string or is_tuple_of_str_with_len_2):
+ raise ValueError(
+ f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
+ )
+ self._input_fields = kwargs
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ """
+ Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
+ subclasses to implement specific model preparation logic.
+ """
+ self._count_prepared += 1
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ """
+ Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
+ in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
+ modifications made during `prepare_models`.
+ """
+ pass
+
+ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
+ raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
+
+ def __call__(self, data: List["BlockState"]) -> Any:
+ if not all(hasattr(d, "noise_pred") for d in data):
+ raise ValueError("Expected all data to have `noise_pred` attribute.")
+ if len(data) != self.num_conditions:
+ raise ValueError(
+ f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
+ )
+ forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
+ return self.forward(**forward_inputs)
+
+ def forward(self, *args, **kwargs) -> Any:
+ raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
+
+ @property
+ def is_conditional(self) -> bool:
+ raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
+
+ @property
+ def is_unconditional(self) -> bool:
+ return not self.is_conditional
+
+ @property
+ def num_conditions(self) -> int:
+ raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
+
+ @classmethod
+ def _prepare_batch(
+ cls,
+ input_fields: Dict[str, Union[str, Tuple[str, str]]],
+ data: "BlockState",
+ tuple_index: int,
+ identifier: str,
+ ) -> "BlockState":
+ """
+ Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
+ `BaseGuidance` class. It prepares the batch based on the provided tuple index.
+
+ Args:
+ input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
+ A dictionary where the keys are the names of the fields that will be used to store the data once it is
+ prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
+ to look up the required data provided for preparation. If a string is provided, it will be used as the
+ conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
+ length 2 is provided, the first element must be the conditional data identifier and the second element
+ must be the unconditional data identifier or None.
+ data (`BlockState`):
+ The input data to be prepared.
+ tuple_index (`int`):
+ The index to use when accessing input fields that are tuples.
+
+ Returns:
+ `BlockState`: The prepared batch of data.
+ """
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+ if input_fields is None:
+ raise ValueError(
+ "Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
+ )
+ data_batch = {}
+ for key, value in input_fields.items():
+ try:
+ if isinstance(value, str):
+ data_batch[key] = getattr(data, value)
+ elif isinstance(value, tuple):
+ data_batch[key] = getattr(data, value[tuple_index])
+ else:
+ # We've already checked that value is a string or a tuple of strings with length 2
+ pass
+ except AttributeError:
+ logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
+ data_batch[cls._identifier_key] = identifier
+ return BlockState(**data_batch)
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ subfolder: Optional[str] = None,
+ return_unused_kwargs=False,
+ **kwargs,
+ ) -> Self:
+ r"""
+ Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
+
+ Parameters:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
+ saved with [`~BaseGuidance.save_pretrained`].
+ subfolder (`str`, *optional*):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_unused_kwargs (`bool`, *optional*, defaults to `False`):
+ Whether kwargs that are not consumed by the Python class should be returned or not.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ output_loading_info(`bool`, *optional*, defaults to `False`):
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+ local_files_only(`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+
+
+
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`. You can also activate the special
+ ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+ firewalled environment.
+
+
+
+ """
+ config, kwargs, commit_hash = cls.load_config(
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ subfolder=subfolder,
+ return_unused_kwargs=True,
+ return_commit_hash=True,
+ **kwargs,
+ )
+ return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save a guider configuration object to a directory so that it can be reloaded using the
+ [`~BaseGuidance.from_pretrained`] class method.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Directory where the configuration JSON file will be saved (will be created if it does not exist).
+ push_to_hub (`bool`, *optional*, defaults to `False`):
+ Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
+ repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
+ namespace).
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
+ """
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ r"""
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+ Args:
+ noise_cfg (`torch.Tensor`):
+ The predicted noise tensor for the guided diffusion process.
+ noise_pred_text (`torch.Tensor`):
+ The predicted noise tensor for the text-guided diffusion process.
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
+ A rescale factor applied to the noise predictions.
+ Returns:
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+ return noise_cfg
diff --git a/src/diffusers/guiders/perturbed_attention_guidance.py b/src/diffusers/guiders/perturbed_attention_guidance.py
new file mode 100644
index 0000000000..1b2256732f
--- /dev/null
+++ b/src/diffusers/guiders/perturbed_attention_guidance.py
@@ -0,0 +1,271 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from ..utils import get_logger
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+
+class PerturbedAttentionGuidance(BaseGuidance):
+ """
+ Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
+
+ The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
+ worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
+ of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
+ attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
+ layers.
+
+ Additional reading:
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
+
+ PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
+ and implementation details.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ perturbed_guidance_scale (`float`, defaults to `2.8`):
+ The scale parameter for perturbed attention guidance.
+ perturbed_guidance_start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which perturbed attention guidance starts.
+ perturbed_guidance_stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which perturbed attention guidance stops.
+ perturbed_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
+ If not provided, `perturbed_guidance_config` must be provided.
+ perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ # NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
+ # the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
+ # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
+ # for each model architecture.
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ perturbed_guidance_scale: float = 2.8,
+ perturbed_guidance_start: float = 0.01,
+ perturbed_guidance_stop: float = 0.2,
+ perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
+ perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.skip_layer_guidance_scale = perturbed_guidance_scale
+ self.skip_layer_guidance_start = perturbed_guidance_start
+ self.skip_layer_guidance_stop = perturbed_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if perturbed_guidance_config is None:
+ if perturbed_guidance_layers is None:
+ raise ValueError(
+ "`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
+ )
+ perturbed_guidance_config = LayerSkipConfig(
+ indices=perturbed_guidance_layers,
+ fqn="auto",
+ skip_attention=False,
+ skip_attention_scores=True,
+ skip_ff=False,
+ )
+ else:
+ if perturbed_guidance_layers is not None:
+ raise ValueError(
+ "`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
+ )
+
+ if isinstance(perturbed_guidance_config, dict):
+ perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
+
+ if isinstance(perturbed_guidance_config, LayerSkipConfig):
+ perturbed_guidance_config = [perturbed_guidance_config]
+
+ if not isinstance(perturbed_guidance_config, list):
+ raise ValueError(
+ "`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
+ )
+ elif isinstance(next(iter(perturbed_guidance_config), None), dict):
+ perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
+
+ for config in perturbed_guidance_config:
+ if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
+ logger.warning(
+ "Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
+ "Please check your configuration. Modifying the config to match the expected values."
+ )
+ config.skip_attention = False
+ config.skip_attention_scores = True
+ config.skip_ff = False
+
+ self.skip_layer_config = perturbed_guidance_config
+ self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._skip_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_skip: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_slg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_cond_skip
+ pred = pred + self.skip_layer_guidance_scale * shift
+ elif not self._is_slg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_skip = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_slg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+ def _is_slg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py
new file mode 100644
index 0000000000..68a657960a
--- /dev/null
+++ b/src/diffusers/guiders/skip_layer_guidance.py
@@ -0,0 +1,262 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class SkipLayerGuidance(BaseGuidance):
+ """
+ Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
+
+ Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
+
+ SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
+ skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
+ batch of data, apart from the conditional and unconditional batches already used in CFG
+ ([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
+ based on the difference between conditional without skipping and conditional with skipping predictions.
+
+ The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
+ worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
+ version of the model for the conditional prediction).
+
+ STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
+ generation quality in video diffusion models.
+
+ Additional reading:
+ - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
+
+ The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
+ defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ skip_layer_guidance_scale (`float`, defaults to `2.8`):
+ The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
+ values, but it may also lead to overexposure and saturation.
+ skip_layer_guidance_start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which skip layer guidance starts.
+ skip_layer_guidance_stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which skip layer guidance stops.
+ skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
+ provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
+ 3.5 Medium.
+ skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
+ The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
+ `LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ skip_layer_guidance_scale: float = 2.8,
+ skip_layer_guidance_start: float = 0.01,
+ skip_layer_guidance_stop: float = 0.2,
+ skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
+ skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.skip_layer_guidance_scale = skip_layer_guidance_scale
+ self.skip_layer_guidance_start = skip_layer_guidance_start
+ self.skip_layer_guidance_stop = skip_layer_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if not (0.0 <= skip_layer_guidance_start < 1.0):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
+ )
+ if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
+ )
+
+ if skip_layer_guidance_layers is None and skip_layer_config is None:
+ raise ValueError(
+ "Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
+ )
+ if skip_layer_guidance_layers is not None and skip_layer_config is not None:
+ raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
+
+ if skip_layer_guidance_layers is not None:
+ if isinstance(skip_layer_guidance_layers, int):
+ skip_layer_guidance_layers = [skip_layer_guidance_layers]
+ if not isinstance(skip_layer_guidance_layers, list):
+ raise ValueError(
+ f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
+ )
+ skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
+
+ if isinstance(skip_layer_config, dict):
+ skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
+
+ if isinstance(skip_layer_config, LayerSkipConfig):
+ skip_layer_config = [skip_layer_config]
+
+ if not isinstance(skip_layer_config, list):
+ raise ValueError(
+ f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
+ )
+ elif isinstance(next(iter(skip_layer_config), None), dict):
+ skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
+
+ self.skip_layer_config = skip_layer_config
+ self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ self._count_prepared += 1
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+ _apply_layer_skip_hook(denoiser, config, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._skip_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_skip: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_slg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_cond_skip
+ pred = pred + self.skip_layer_guidance_scale * shift
+ elif not self._is_slg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_skip = pred_cond - pred_cond_skip
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_slg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_slg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py
new file mode 100644
index 0000000000..d8e8a3cf2f
--- /dev/null
+++ b/src/diffusers/guiders/smoothed_energy_guidance.py
@@ -0,0 +1,251 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from ..hooks import HookRegistry
+from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class SmoothedEnergyGuidance(BaseGuidance):
+ """
+ Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
+
+ SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+ future without warning or guarantee of reproducibility. This implementation assumes:
+ - Generated images are square (height == width)
+ - The model does not combine different modalities together (e.g., text and image latent streams are not combined
+ together such as Flux)
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ seg_guidance_scale (`float`, defaults to `3.0`):
+ The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
+ values, but it may also lead to overexposure and saturation.
+ seg_blur_sigma (`float`, defaults to `9999999.0`):
+ The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
+ infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
+ seg_blur_threshold_inf (`float`, defaults to `9999.0`):
+ The threshold above which the blur is considered infinite.
+ seg_guidance_start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which smoothed energy guidance starts.
+ seg_guidance_stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which smoothed energy guidance stops.
+ seg_guidance_layers (`int` or `List[int]`, *optional*):
+ The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
+ not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
+ Diffusion 3.5 Medium.
+ seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
+ The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
+ a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.01`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `0.2`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ seg_guidance_scale: float = 2.8,
+ seg_blur_sigma: float = 9999999.0,
+ seg_blur_threshold_inf: float = 9999.0,
+ seg_guidance_start: float = 0.0,
+ seg_guidance_stop: float = 1.0,
+ seg_guidance_layers: Optional[Union[int, List[int]]] = None,
+ seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.seg_guidance_scale = seg_guidance_scale
+ self.seg_blur_sigma = seg_blur_sigma
+ self.seg_blur_threshold_inf = seg_blur_threshold_inf
+ self.seg_guidance_start = seg_guidance_start
+ self.seg_guidance_stop = seg_guidance_stop
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ if not (0.0 <= seg_guidance_start < 1.0):
+ raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
+ if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
+ raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
+
+ if seg_guidance_layers is None and seg_guidance_config is None:
+ raise ValueError(
+ "Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
+ )
+ if seg_guidance_layers is not None and seg_guidance_config is not None:
+ raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
+
+ if seg_guidance_layers is not None:
+ if isinstance(seg_guidance_layers, int):
+ seg_guidance_layers = [seg_guidance_layers]
+ if not isinstance(seg_guidance_layers, list):
+ raise ValueError(
+ f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
+ )
+ seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
+
+ if isinstance(seg_guidance_config, dict):
+ seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
+
+ if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
+ seg_guidance_config = [seg_guidance_config]
+
+ if not isinstance(seg_guidance_config, list):
+ raise ValueError(
+ f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
+ )
+ elif isinstance(next(iter(seg_guidance_config), None), dict):
+ seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
+
+ self.seg_guidance_config = seg_guidance_config
+ self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
+
+ def prepare_models(self, denoiser: torch.nn.Module) -> None:
+ if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
+ for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
+ _apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
+
+ def cleanup_models(self, denoiser: torch.nn.Module):
+ if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
+ registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+ # Remove the hooks after inference
+ for hook_name in self._seg_layer_hook_names:
+ registry.remove_hook(hook_name, recurse=True)
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ if self.num_conditions == 1:
+ tuple_indices = [0]
+ input_predictions = ["pred_cond"]
+ elif self.num_conditions == 2:
+ tuple_indices = [0, 1]
+ input_predictions = (
+ ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
+ )
+ else:
+ tuple_indices = [0, 1, 0]
+ input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(
+ self,
+ pred_cond: torch.Tensor,
+ pred_uncond: Optional[torch.Tensor] = None,
+ pred_cond_seg: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ pred = None
+
+ if not self._is_cfg_enabled() and not self._is_seg_enabled():
+ pred = pred_cond
+ elif not self._is_cfg_enabled():
+ shift = pred_cond - pred_cond_seg
+ pred = pred_cond if self.use_original_formulation else pred_cond_seg
+ pred = pred + self.seg_guidance_scale * shift
+ elif not self._is_seg_enabled():
+ shift = pred_cond - pred_uncond
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift
+ else:
+ shift = pred_cond - pred_uncond
+ shift_seg = pred_cond - pred_cond_seg
+ pred = pred_cond if self.use_original_formulation else pred_uncond
+ pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._count_prepared == 1 or self._count_prepared == 3
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_cfg_enabled():
+ num_conditions += 1
+ if self._is_seg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_cfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+ def _is_seg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
+ skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
+ is_within_range = skip_start_step < self._step < skip_stop_step
+
+ is_zero = math.isclose(self.seg_guidance_scale, 0.0)
+
+ return is_within_range and not is_zero
diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py
new file mode 100644
index 0000000000..b3187e5263
--- /dev/null
+++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py
@@ -0,0 +1,143 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import register_to_config
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+ from ..modular_pipelines.modular_pipeline import BlockState
+
+
+class TangentialClassifierFreeGuidance(BaseGuidance):
+ """
+ Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
+
+ Args:
+ guidance_scale (`float`, defaults to `7.5`):
+ The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
+ prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
+ deterioration of image quality.
+ guidance_rescale (`float`, defaults to `0.0`):
+ The rescale factor applied to the noise predictions. This is used to improve image quality and fix
+ overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+ Flawed](https://huggingface.co/papers/2305.08891).
+ use_original_formulation (`bool`, defaults to `False`):
+ Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
+ we use the diffusers-native implementation that has been in the codebase for a long time. See
+ [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
+ start (`float`, defaults to `0.0`):
+ The fraction of the total number of denoising steps after which guidance starts.
+ stop (`float`, defaults to `1.0`):
+ The fraction of the total number of denoising steps after which guidance stops.
+ """
+
+ _input_predictions = ["pred_cond", "pred_uncond"]
+
+ @register_to_config
+ def __init__(
+ self,
+ guidance_scale: float = 7.5,
+ guidance_rescale: float = 0.0,
+ use_original_formulation: bool = False,
+ start: float = 0.0,
+ stop: float = 1.0,
+ ):
+ super().__init__(start, stop)
+
+ self.guidance_scale = guidance_scale
+ self.guidance_rescale = guidance_rescale
+ self.use_original_formulation = use_original_formulation
+
+ def prepare_inputs(
+ self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+ ) -> List["BlockState"]:
+ if input_fields is None:
+ input_fields = self._input_fields
+
+ tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
+ data_batches = []
+ for i in range(self.num_conditions):
+ data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
+ data_batches.append(data_batch)
+ return data_batches
+
+ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
+ pred = None
+
+ if not self._is_tcfg_enabled():
+ pred = pred_cond
+ else:
+ pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
+
+ if self.guidance_rescale > 0.0:
+ pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+ return pred, {}
+
+ @property
+ def is_conditional(self) -> bool:
+ return self._num_outputs_prepared == 1
+
+ @property
+ def num_conditions(self) -> int:
+ num_conditions = 1
+ if self._is_tcfg_enabled():
+ num_conditions += 1
+ return num_conditions
+
+ def _is_tcfg_enabled(self) -> bool:
+ if not self._enabled:
+ return False
+
+ is_within_range = True
+ if self._num_inference_steps is not None:
+ skip_start_step = int(self._start * self._num_inference_steps)
+ skip_stop_step = int(self._stop * self._num_inference_steps)
+ is_within_range = skip_start_step <= self._step < skip_stop_step
+
+ is_close = False
+ if self.use_original_formulation:
+ is_close = math.isclose(self.guidance_scale, 0.0)
+ else:
+ is_close = math.isclose(self.guidance_scale, 1.0)
+
+ return is_within_range and not is_close
+
+
+def normalized_guidance(
+ pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
+) -> torch.Tensor:
+ cond_dtype = pred_cond.dtype
+ preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
+ preds = preds.flatten(2)
+ U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
+ Vh_modified = Vh.clone()
+ Vh_modified[:, 1] = 0
+
+ uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
+ x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
+ x_Vh_V = torch.matmul(x_Vh, Vh_modified)
+ pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
+
+ pred = pred_cond if use_original_formulation else pred_uncond
+ shift = pred_cond - pred_uncond
+ pred = pred + guidance_scale * shift
+
+ return pred
diff --git a/src/diffusers/hooks/__init__.py b/src/diffusers/hooks/__init__.py
index 764ceb25b4..525a0747da 100644
--- a/src/diffusers/hooks/__init__.py
+++ b/src/diffusers/hooks/__init__.py
@@ -1,9 +1,26 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
from ..utils import is_torch_available
if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
+ from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
+ from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
+ from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
diff --git a/src/diffusers/hooks/_common.py b/src/diffusers/hooks/_common.py
new file mode 100644
index 0000000000..ca7934e5c3
--- /dev/null
+++ b/src/diffusers/hooks/_common.py
@@ -0,0 +1,56 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import torch
+
+from ..models.attention import AttentionModuleMixin, FeedForward, LuminaFeedForward
+from ..models.attention_processor import Attention, MochiAttention
+
+
+_ATTENTION_CLASSES = (Attention, MochiAttention, AttentionModuleMixin)
+_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
+
+_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
+_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
+_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
+
+_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
+ {
+ *_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ *_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ }
+)
+
+# Layers supported for group offloading and layerwise casting
+_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
+ torch.nn.Conv1d,
+ torch.nn.Conv2d,
+ torch.nn.Conv3d,
+ torch.nn.ConvTranspose1d,
+ torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d,
+ torch.nn.Linear,
+ # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+ # because of double invocation of the same norm layer in CogVideoXLayerNorm
+)
+
+
+def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
+ for submodule_name, submodule in module.named_modules():
+ if submodule_name == fqn:
+ return submodule
+ return None
diff --git a/src/diffusers/hooks/_helpers.py b/src/diffusers/hooks/_helpers.py
new file mode 100644
index 0000000000..9b558ddb21
--- /dev/null
+++ b/src/diffusers/hooks/_helpers.py
@@ -0,0 +1,282 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Type
+
+
+@dataclass
+class AttentionProcessorMetadata:
+ skip_processor_output_fn: Callable[[Any], Any]
+
+
+@dataclass
+class TransformerBlockMetadata:
+ return_hidden_states_index: int = None
+ return_encoder_hidden_states_index: int = None
+
+ _cls: Type = None
+ _cached_parameter_indices: Dict[str, int] = None
+
+ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
+ kwargs = kwargs or {}
+ if identifier in kwargs:
+ return kwargs[identifier]
+ if self._cached_parameter_indices is not None:
+ return args[self._cached_parameter_indices[identifier]]
+ if self._cls is None:
+ raise ValueError("Model class is not set for metadata.")
+ parameters = list(inspect.signature(self._cls.forward).parameters.keys())
+ parameters = parameters[1:] # skip `self`
+ self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
+ if identifier not in self._cached_parameter_indices:
+ raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
+ index = self._cached_parameter_indices[identifier]
+ if index >= len(args):
+ raise ValueError(f"Expected {index} arguments but got {len(args)}.")
+ return args[index]
+
+
+class AttentionProcessorRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
+ cls._register()
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> AttentionProcessorMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_attention_processors_metadata()
+
+
+class TransformerBlockRegistry:
+ _registry = {}
+ # TODO(aryan): this is only required for the time being because we need to do the registrations
+ # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
+ # import errors because of the models imported in this file.
+ _is_registered = False
+
+ @classmethod
+ def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
+ cls._register()
+ metadata._cls = model_class
+ cls._registry[model_class] = metadata
+
+ @classmethod
+ def get(cls, model_class: Type) -> TransformerBlockMetadata:
+ cls._register()
+ if model_class not in cls._registry:
+ raise ValueError(f"Model class {model_class} not registered.")
+ return cls._registry[model_class]
+
+ @classmethod
+ def _register(cls):
+ if cls._is_registered:
+ return
+ cls._is_registered = True
+ _register_transformer_blocks_metadata()
+
+
+def _register_attention_processors_metadata():
+ from ..models.attention_processor import AttnProcessor2_0
+ from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
+ from ..models.transformers.transformer_flux import FluxAttnProcessor
+ from ..models.transformers.transformer_wan import WanAttnProcessor2_0
+
+ # AttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=AttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
+ ),
+ )
+
+ # CogView4AttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=CogView4AttnProcessor,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
+ ),
+ )
+
+ # WanAttnProcessor2_0
+ AttentionProcessorRegistry.register(
+ model_class=WanAttnProcessor2_0,
+ metadata=AttentionProcessorMetadata(
+ skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
+ ),
+ )
+ # FluxAttnProcessor
+ AttentionProcessorRegistry.register(
+ model_class=FluxAttnProcessor,
+ metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
+ )
+
+
+def _register_transformer_blocks_metadata():
+ from ..models.attention import BasicTransformerBlock
+ from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
+ from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
+ from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+ from ..models.transformers.transformer_hunyuan_video import (
+ HunyuanVideoSingleTransformerBlock,
+ HunyuanVideoTokenReplaceSingleTransformerBlock,
+ HunyuanVideoTokenReplaceTransformerBlock,
+ HunyuanVideoTransformerBlock,
+ )
+ from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
+ from ..models.transformers.transformer_mochi import MochiTransformerBlock
+ from ..models.transformers.transformer_wan import WanTransformerBlock
+
+ # BasicTransformerBlock
+ TransformerBlockRegistry.register(
+ model_class=BasicTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # CogVideoX
+ TransformerBlockRegistry.register(
+ model_class=CogVideoXBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # CogView4
+ TransformerBlockRegistry.register(
+ model_class=CogView4TransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Flux
+ TransformerBlockRegistry.register(
+ model_class=FluxTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=FluxSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=1,
+ return_encoder_hidden_states_index=0,
+ ),
+ )
+
+ # HunyuanVideo
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+ TransformerBlockRegistry.register(
+ model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # LTXVideo
+ TransformerBlockRegistry.register(
+ model_class=LTXVideoTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+ # Mochi
+ TransformerBlockRegistry.register(
+ model_class=MochiTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=1,
+ ),
+ )
+
+ # Wan
+ TransformerBlockRegistry.register(
+ model_class=WanTransformerBlock,
+ metadata=TransformerBlockMetadata(
+ return_hidden_states_index=0,
+ return_encoder_hidden_states_index=None,
+ ),
+ )
+
+
+# fmt: off
+def _skip_attention___ret___hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ return hidden_states
+
+
+def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
+ hidden_states = kwargs.get("hidden_states", None)
+ encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+ if hidden_states is None and len(args) > 0:
+ hidden_states = args[0]
+ if encoder_hidden_states is None and len(args) > 1:
+ encoder_hidden_states = args[1]
+ return hidden_states, encoder_hidden_states
+
+
+_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
+_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
+_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
+# not sure what this is yet.
+_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
+# fmt: on
diff --git a/src/diffusers/hooks/faster_cache.py b/src/diffusers/hooks/faster_cache.py
index 1be5e14362..53e5bd792c 100644
--- a/src/diffusers/hooks/faster_cache.py
+++ b/src/diffusers/hooks/faster_cache.py
@@ -18,9 +18,10 @@ from typing import Any, Callable, List, Optional, Tuple
import torch
-from ..models.attention_processor import Attention, MochiAttention
+from ..models.attention import AttentionModuleMixin
from ..models.modeling_outputs import Transformer2DModelOutput
from ..utils import logging
+from ._common import _ATTENTION_CLASSES
from .hooks import HookRegistry, ModelHook
@@ -29,7 +30,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser"
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
"^blocks.*attn",
"^transformer_blocks.*attn",
@@ -488,9 +488,10 @@ def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> No
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline.
Args:
- pipeline (`DiffusionPipeline`):
- The diffusion pipeline to apply FasterCache to.
- config (`Optional[FasterCacheConfig]`, `optional`, defaults to `None`):
+ module (`torch.nn.Module`):
+ The pytorch module to apply FasterCache to. Typically, this should be a transformer architecture supported
+ in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+ config (`FasterCacheConfig`):
The configuration to use for FasterCache.
Example:
@@ -588,7 +589,7 @@ def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCache
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK)
-def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None:
+def _apply_faster_cache_on_attention_class(name: str, module: AttentionModuleMixin, config: FasterCacheConfig) -> None:
is_spatial_self_attention = (
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers)
and config.spatial_attention_block_skip_range is not None
diff --git a/src/diffusers/hooks/first_block_cache.py b/src/diffusers/hooks/first_block_cache.py
new file mode 100644
index 0000000000..862d440593
--- /dev/null
+++ b/src/diffusers/hooks/first_block_cache.py
@@ -0,0 +1,259 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Tuple, Union
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
+from ._helpers import TransformerBlockRegistry
+from .hooks import BaseState, HookRegistry, ModelHook, StateManager
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
+_FBC_BLOCK_HOOK = "fbc_block_hook"
+
+
+@dataclass
+class FirstBlockCacheConfig:
+ r"""
+ Configuration for [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
+
+ Args:
+ threshold (`float`, defaults to `0.05`):
+ The threshold to determine whether or not a forward pass through all layers of the model is required. A
+ higher threshold usually results in a forward pass through a lower number of layers and faster inference,
+ but might lead to poorer generation quality. A lower threshold may not result in significant generation
+ speedup. The threshold is compared against the absmean difference of the residuals between the current and
+ cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
+ is skipped.
+ """
+
+ threshold: float = 0.05
+
+
+class FBCSharedBlockState(BaseState):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.head_block_residual: torch.Tensor = None
+ self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
+ self.should_compute: bool = True
+
+ def reset(self):
+ self.tail_block_residuals = None
+ self.should_compute = True
+
+
+class FBCHeadBlockHook(ModelHook):
+ _is_stateful = True
+
+ def __init__(self, state_manager: StateManager, threshold: float):
+ self.state_manager = state_manager
+ self.threshold = threshold
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ is_output_tuple = isinstance(output, tuple)
+
+ if is_output_tuple:
+ hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
+ else:
+ hidden_states_residual = output - original_hidden_states
+
+ shared_state: FBCSharedBlockState = self.state_manager.get_state()
+ hidden_states = encoder_hidden_states = None
+ should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
+ shared_state.should_compute = should_compute
+
+ if not should_compute:
+ # Apply caching
+ if is_output_tuple:
+ hidden_states = (
+ shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
+ )
+ else:
+ hidden_states = shared_state.tail_block_residuals[0] + output
+
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ assert is_output_tuple
+ encoder_hidden_states = (
+ shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
+ )
+
+ if is_output_tuple:
+ return_output = [None] * len(output)
+ return_output[self._metadata.return_hidden_states_index] = hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
+ return_output = tuple(return_output)
+ else:
+ return_output = hidden_states
+ output = return_output
+ else:
+ if is_output_tuple:
+ head_block_output = [None] * len(output)
+ head_block_output[0] = output[self._metadata.return_hidden_states_index]
+ head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
+ else:
+ head_block_output = output
+ shared_state.head_block_output = head_block_output
+ shared_state.head_block_residual = hidden_states_residual
+
+ return output
+
+ def reset_state(self, module):
+ self.state_manager.reset()
+ return module
+
+ @torch.compiler.disable
+ def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
+ shared_state = self.state_manager.get_state()
+ if shared_state.head_block_residual is None:
+ return True
+ prev_hidden_states_residual = shared_state.head_block_residual
+ absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
+ prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
+ diff = (absmean / prev_hidden_states_absmean).item()
+ return diff > self.threshold
+
+
+class FBCBlockHook(ModelHook):
+ def __init__(self, state_manager: StateManager, is_tail: bool = False):
+ super().__init__()
+ self.state_manager = state_manager
+ self.is_tail = is_tail
+ self._metadata = None
+
+ def initialize_hook(self, module):
+ unwrapped_module = unwrap_module(module)
+ self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ original_encoder_hidden_states = None
+ if self._metadata.return_encoder_hidden_states_index is not None:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+
+ shared_state = self.state_manager.get_state()
+
+ if shared_state.should_compute:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ if self.is_tail:
+ hidden_states_residual = encoder_hidden_states_residual = None
+ if isinstance(output, tuple):
+ hidden_states_residual = (
+ output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
+ )
+ encoder_hidden_states_residual = (
+ output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
+ )
+ else:
+ hidden_states_residual = output - shared_state.head_block_output
+ shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
+ return output
+
+ if original_encoder_hidden_states is None:
+ return_output = original_hidden_states
+ else:
+ return_output = [None, None]
+ return_output[self._metadata.return_hidden_states_index] = original_hidden_states
+ return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
+ return_output = tuple(return_output)
+ return return_output
+
+
+def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
+ """
+ Applies [First Block
+ Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
+ to a given module.
+
+ First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
+ to implement generically for a wide range of models and has been integrated first for experimental purposes.
+
+ Args:
+ module (`torch.nn.Module`):
+ The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
+ Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
+ config (`FirstBlockCacheConfig`):
+ The configuration to use for applying the FBCache method.
+
+ Example:
+ ```python
+ >>> import torch
+ >>> from diffusers import CogView4Pipeline
+ >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+
+ >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
+ >>> pipe.to("cuda")
+
+ >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
+
+ >>> prompt = "A photo of an astronaut riding a horse on mars"
+ >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
+ >>> image.save("output.png")
+ ```
+ """
+
+ state_manager = StateManager(FBCSharedBlockState, (), {})
+ remaining_blocks = []
+
+ for name, submodule in module.named_children():
+ if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
+ continue
+ for index, block in enumerate(submodule):
+ remaining_blocks.append((f"{name}.{index}", block))
+
+ head_block_name, head_block = remaining_blocks.pop(0)
+ tail_block_name, tail_block = remaining_blocks.pop(-1)
+
+ logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
+ _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
+
+ for name, block in remaining_blocks:
+ logger.debug(f"Applying FBCBlockHook to '{name}'")
+ _apply_fbc_block_hook(block, state_manager)
+
+ logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
+ _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
+
+
+def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCHeadBlockHook(state_manager, threshold)
+ registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
+
+
+def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = FBCBlockHook(state_manager, is_tail)
+ registry.register_hook(hook, _FBC_BLOCK_HOOK)
diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 7186cb181a..3015409afc 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
import os
from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass
+from enum import Enum
from typing import Dict, List, Optional, Set, Tuple, Union
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,17 +39,28 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
-
-_SUPPORTED_PYTORCH_LAYERS = (
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
- torch.nn.Linear,
- # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
- # because of double invocation of the same norm layer in CogVideoXLayerNorm
-)
+_GROUP_ID_LAZY_LEAF = "lazy_leafs"
# fmt: on
+class GroupOffloadingType(str, Enum):
+ BLOCK_LEVEL = "block_level"
+ LEAF_LEVEL = "leaf_level"
+
+
+@dataclass
+class GroupOffloadingConfig:
+ onload_device: torch.device
+ offload_device: torch.device
+ offload_type: GroupOffloadingType
+ non_blocking: bool
+ record_stream: bool
+ low_cpu_mem_usage: bool
+ num_blocks_per_group: Optional[int] = None
+ offload_to_disk_path: Optional[str] = None
+ stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
+
+
class ModuleGroup:
def __init__(
self,
@@ -62,6 +77,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
+ group_id: Optional[int] = None,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -80,7 +96,10 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
- self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
+ # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
+ self.group_id = group_id if group_id is not None else str(id(self))
+ short_hash = _compute_group_hash(self.group_id)
+ self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{short_hash}.safetensors")
all_tensors = []
for module in self.modules:
@@ -132,9 +151,58 @@ class ModuleGroup:
finally:
pinned_dict = None
+ def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+ tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+ if self.record_stream and current_stream is not None:
+ tensor.data.record_stream(current_stream)
+
+ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+ for group_module in self.modules:
+ for param in group_module.parameters():
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source, current_stream)
+ for buffer in group_module.buffers():
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source, current_stream)
+
+ for param in self.parameters:
+ source = pinned_memory[param] if pinned_memory else param.data
+ self._transfer_tensor_to_device(param, source, current_stream)
+
+ for buffer in self.buffers:
+ source = pinned_memory[buffer] if pinned_memory else buffer.data
+ self._transfer_tensor_to_device(buffer, source, current_stream)
+
+ def _onload_from_disk(self, current_stream):
+ if self.stream is not None:
+ loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+ for key, tensor_obj in self.key_to_tensor.items():
+ self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+ with self._pinned_memory_tensors() as pinned_memory:
+ for key, tensor_obj in self.key_to_tensor.items():
+ self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+ self.cpu_param_dict.clear()
+
+ else:
+ onload_device = (
+ self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+ )
+ loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+ for key, tensor_obj in self.key_to_tensor.items():
+ tensor_obj.data = loaded_tensors[key]
+
+ def _onload_from_memory(self, current_stream):
+ if self.stream is not None:
+ with self._pinned_memory_tensors() as pinned_memory:
+ self._process_tensors_from_modules(pinned_memory, current_stream)
+ else:
+ self._process_tensors_from_modules(None, current_stream)
+
@torch.compiler.disable()
def onload_(self):
- r"""Onloads the group of modules to the onload_device."""
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
@@ -172,67 +240,30 @@ class ModuleGroup:
self.stream.synchronize()
with context:
- if self.stream is not None:
- with self._pinned_memory_tensors() as pinned_memory:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- param.data.record_stream(current_stream)
- for buffer in group_module.buffers():
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
-
- for param in self.parameters:
- param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- param.data.record_stream(current_stream)
-
- for buffer in self.buffers:
- buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
-
+ if self.offload_to_disk_path:
+ self._onload_from_disk(current_stream)
else:
- for group_module in self.modules:
- for param in group_module.parameters():
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
- for buffer in group_module.buffers():
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+ self._onload_from_memory(current_stream)
- for param in self.parameters:
- param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+ def _offload_to_disk(self):
+ # TODO: we can potentially optimize this code path by checking if the _all_ the desired
+ # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+ # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+ # we perform a write.
+ # Check if the file has been saved in this session or if it already exists on disk.
+ if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+ os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+ tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+ safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
- for buffer in self.buffers:
- buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
- if self.record_stream:
- buffer.data.record_stream(current_stream)
+ # The group is now considered offloaded to disk for the rest of the session.
+ self._is_offloaded_to_disk = True
- @torch.compiler.disable()
- def offload_(self):
- r"""Offloads the group of modules to the offload_device."""
- if self.offload_to_disk_path:
- # TODO: we can potentially optimize this code path by checking if the _all_ the desired
- # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
- # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
- # we perform a write.
- # Check if the file has been saved in this session or if it already exists on disk.
- if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
- os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
- tensors_to_save = {
- key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
- }
- safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
- # The group is now considered offloaded to disk for the rest of the session.
- self._is_offloaded_to_disk = True
-
- # We do this to free up the RAM which is still holding the up tensor data.
- for tensor_obj in self.tensor_to_key.keys():
- tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
- return
+ # We do this to free up the RAM which is still holding the up tensor data.
+ for tensor_obj in self.tensor_to_key.keys():
+ tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
+ def _offload_to_memory(self):
torch_accelerator_module = (
getattr(torch, torch.accelerator.current_accelerator().type)
if hasattr(torch, "accelerator")
@@ -257,6 +288,14 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+ @torch.compiler.disable()
+ def offload_(self):
+ r"""Offloads the group of modules to the offload_device."""
+ if self.offload_to_disk_path:
+ self._offload_to_disk()
+ else:
+ self._offload_to_memory()
+
class GroupOffloadingHook(ModelHook):
r"""
@@ -268,9 +307,12 @@ class GroupOffloadingHook(ModelHook):
_is_stateful = False
- def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
+ def __init__(
+ self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
+ ) -> None:
self.group = group
self.next_group = next_group
+ self.config = config
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
@@ -319,7 +361,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
def initialize_hook(self, module):
def make_execution_order_update_callback(current_name, current_submodule):
def callback():
- logger.debug(f"Adding {current_name} to the execution order")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Adding {current_name} to the execution order")
self.execution_order.append((current_name, current_submodule))
return callback
@@ -356,12 +399,13 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
# if the missing layers end up being executed in the future.
if execution_order_module_names != self._layer_execution_tracker_module_names:
unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
- logger.warning(
- "It seems like some layers were not executed during the forward pass. This may lead to problems when "
- "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
- "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
- f"{unexecuted_layers=}"
- )
+ if not torch.compiler.is_compiling():
+ logger.warning(
+ "It seems like some layers were not executed during the forward pass. This may lead to problems when "
+ "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
+ "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
+ f"{unexecuted_layers=}"
+ )
# Remove the layer execution tracker hooks from the submodules
base_module_registry = module._diffusers_hook
@@ -389,7 +433,8 @@ class LazyPrefetchGroupOffloadingHook(ModelHook):
for i in range(num_executed - 1):
name1, _ = self.execution_order[i]
name2, _ = self.execution_order[i + 1]
- logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
+ if not torch.compiler.is_compiling():
+ logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
group_offloading_hooks[i].next_group.onload_self = False
@@ -416,7 +461,7 @@ def apply_group_offloading(
module: torch.nn.Module,
onload_device: torch.device,
offload_device: torch.device = torch.device("cpu"),
- offload_type: str = "block_level",
+ offload_type: Union[str, GroupOffloadingType] = "block_level",
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
@@ -458,7 +503,7 @@ def apply_group_offloading(
The device to which the group of modules are onloaded.
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
- offload_type (`str`, defaults to "block_level"):
+ offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
"block_level".
offload_to_disk_path (`str`, *optional*, defaults to `None`):
@@ -501,6 +546,8 @@ def apply_group_offloading(
```
"""
+ offload_type = GroupOffloadingType(offload_type)
+
stream = None
if use_stream:
if torch.cuda.is_available():
@@ -512,84 +559,45 @@ def apply_group_offloading(
if not use_stream and record_stream:
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
+ if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
+ raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
- if offload_type == "block_level":
- if num_blocks_per_group is None:
- raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
+ config = GroupOffloadingConfig(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ non_blocking=non_blocking,
+ stream=stream,
+ record_stream=record_stream,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ offload_to_disk_path=offload_to_disk_path,
+ )
+ _apply_group_offloading(module, config)
- _apply_group_offloading_block_level(
- module=module,
- num_blocks_per_group=num_blocks_per_group,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
- elif offload_type == "leaf_level":
- _apply_group_offloading_leaf_level(
- module=module,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
- )
+
+def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
+ if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
+ _apply_group_offloading_block_level(module, config)
+ elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
+ _apply_group_offloading_leaf_level(module, config)
else:
- raise ValueError(f"Unsupported offload_type: {offload_type}")
+ assert False
-def _apply_group_offloading_block_level(
- module: torch.nn.Module,
- num_blocks_per_group: int,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
- record_stream: Optional[bool] = False,
- low_cpu_mem_usage: bool = False,
- offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
-
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- offload_to_disk_path (`str`, *optional*, defaults to `None`):
- The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
- RAM environment settings where a reasonable speed-memory trade-off is desired.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
- record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
- as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
- [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
- details.
- low_cpu_mem_usage (`bool`, defaults to `False`):
- If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
- option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
- the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
- if stream is not None and num_blocks_per_group != 1:
+
+ if config.stream is not None and config.num_blocks_per_group != 1:
logger.warning(
- f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
+ f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
)
- num_blocks_per_group = 1
+ config.num_blocks_per_group = 1
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -601,20 +609,22 @@ def _apply_group_offloading_block_level(
modules_with_group_offloading.add(name)
continue
- for i in range(0, len(submodule), num_blocks_per_group):
- current_modules = submodule[i : i + num_blocks_per_group]
+ for i in range(0, len(submodule), config.num_blocks_per_group):
+ current_modules = submodule[i : i + config.num_blocks_per_group]
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
group = ModuleGroup(
modules=current_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=current_modules[-1],
onload_leader=current_modules[0],
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=group_id,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -623,7 +633,7 @@ def _apply_group_offloading_block_level(
# Apply group offloading hooks to the module groups
for i, group in enumerate(matched_module_groups):
for group_module in group.modules:
- _apply_group_offloading_hook(group_module, group, None)
+ _apply_group_offloading_hook(group_module, group, None, config=config)
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
# when the forward pass of this module is called. This is because the top-level module is not
@@ -638,9 +648,9 @@ def _apply_group_offloading_block_level(
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
unmatched_group = ModuleGroup(
modules=unmatched_modules,
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=parameters,
@@ -649,74 +659,41 @@ def _apply_group_offloading_block_level(
stream=None,
record_stream=False,
onload_self=True,
+ group_id=f"{module.__class__.__name__}_unmatched_group",
)
- if stream is None:
- _apply_group_offloading_hook(module, unmatched_group, None)
+ if config.stream is None:
+ _apply_group_offloading_hook(module, unmatched_group, None, config=config)
else:
- _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+ _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
-def _apply_group_offloading_leaf_level(
- module: torch.nn.Module,
- offload_device: torch.device,
- onload_device: torch.device,
- non_blocking: bool,
- stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
- record_stream: Optional[bool] = False,
- low_cpu_mem_usage: bool = False,
- offload_to_disk_path: Optional[str] = None,
-) -> None:
+def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
reduce memory usage without any performance degradation.
-
- Args:
- module (`torch.nn.Module`):
- The module to which group offloading is applied.
- offload_device (`torch.device`):
- The device to which the group of modules are offloaded. This should typically be the CPU.
- onload_device (`torch.device`):
- The device to which the group of modules are onloaded.
- offload_to_disk_path (`str`, *optional*, defaults to `None`):
- The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
- RAM environment settings where a reasonable speed-memory trade-off is desired.
- non_blocking (`bool`):
- If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
- and data transfer.
- stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
- If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
- for overlapping computation and data transfer.
- record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
- as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
- [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
- details.
- low_cpu_mem_usage (`bool`, defaults to `False`):
- If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
- option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
- the CPU memory is a bottleneck but may counteract the benefits of using streams.
"""
-
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
for name, submodule in module.named_modules():
- if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
group = ModuleGroup(
modules=[submodule],
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=submodule,
onload_leader=submodule,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=name,
)
- _apply_group_offloading_hook(submodule, group, None)
+ _apply_group_offloading_hook(submodule, group, None, config=config)
modules_with_group_offloading.add(name)
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -747,33 +724,33 @@ def _apply_group_offloading_leaf_level(
parameters = parent_to_parameters.get(name, [])
buffers = parent_to_buffers.get(name, [])
parent_module = module_dict[name]
- assert getattr(parent_module, "_diffusers_hook", None) is None
group = ModuleGroup(
modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
offload_leader=parent_module,
onload_leader=parent_module,
- offload_to_disk_path=offload_to_disk_path,
+ offload_to_disk_path=config.offload_to_disk_path,
parameters=parameters,
buffers=buffers,
- non_blocking=non_blocking,
- stream=stream,
- record_stream=record_stream,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ non_blocking=config.non_blocking,
+ stream=config.stream,
+ record_stream=config.record_stream,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=name,
)
- _apply_group_offloading_hook(parent_module, group, None)
+ _apply_group_offloading_hook(parent_module, group, None, config=config)
- if stream is not None:
+ if config.stream is not None:
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
# execution order and apply prefetching in the correct order.
unmatched_group = ModuleGroup(
modules=[],
- offload_device=offload_device,
- onload_device=onload_device,
- offload_to_disk_path=offload_to_disk_path,
+ offload_device=config.offload_device,
+ onload_device=config.onload_device,
+ offload_to_disk_path=config.offload_to_disk_path,
offload_leader=module,
onload_leader=module,
parameters=None,
@@ -781,23 +758,26 @@ def _apply_group_offloading_leaf_level(
non_blocking=False,
stream=None,
record_stream=False,
- low_cpu_mem_usage=low_cpu_mem_usage,
+ low_cpu_mem_usage=config.low_cpu_mem_usage,
onload_self=True,
+ group_id=_GROUP_ID_LAZY_LEAF,
)
- _apply_lazy_group_offloading_hook(module, unmatched_group, None)
+ _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
def _apply_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
@@ -805,13 +785,15 @@ def _apply_lazy_group_offloading_hook(
module: torch.nn.Module,
group: ModuleGroup,
next_group: Optional[ModuleGroup] = None,
+ *,
+ config: GroupOffloadingConfig,
) -> None:
registry = HookRegistry.check_if_exists_or_initialize(module)
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
if registry.get_hook(_GROUP_OFFLOADING) is None:
- hook = GroupOffloadingHook(group, next_group)
+ hook = GroupOffloadingHook(group, next_group, config=config)
registry.register_hook(hook, _GROUP_OFFLOADING)
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
@@ -878,15 +860,54 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
)
-def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return True
- return False
+ if hasattr(submodule, "_diffusers_hook"):
+ group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
+ if group_offloading_hook is not None:
+ return group_offloading_hook
+ return None
+
+
+def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ return top_level_group_offload_hook is not None
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
- for submodule in module.modules():
- if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
- return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+ if top_level_group_offload_hook is not None:
+ return top_level_group_offload_hook.config.onload_device
raise ValueError("Group offloading is not enabled for the provided module.")
+
+
+def _compute_group_hash(group_id):
+ hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
+ # first 16 characters for a reasonably short but unique name
+ return hashed_id[:16]
+
+
+def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
+ r"""
+ Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
+ modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
+ modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
+
+ In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
+ and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
+ case where user has applied group offloading at multiple levels, this function will not work as expected.
+
+ There is some performance penalty associated with doing this when non-default streams are used, because we need to
+ retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
+ """
+ top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
+
+ if top_level_group_offload_hook is None:
+ return
+
+ registry = HookRegistry.check_if_exists_or_initialize(module)
+ registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
+ registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
+ registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
+
+ _apply_group_offloading(module, top_level_group_offload_hook.config)
diff --git a/src/diffusers/hooks/hooks.py b/src/diffusers/hooks/hooks.py
index 96231aadc3..6e097e5882 100644
--- a/src/diffusers/hooks/hooks.py
+++ b/src/diffusers/hooks/hooks.py
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
import torch
from ..utils.logging import get_logger
+from ..utils.torch_utils import unwrap_module
logger = get_logger(__name__) # pylint: disable=invalid-name
+class BaseState:
+ def reset(self, *args, **kwargs) -> None:
+ raise NotImplementedError(
+ "BaseState::reset is not implemented. Please implement this method in the derived class."
+ )
+
+
+class StateManager:
+ def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
+ self._state_cls = state_cls
+ self._init_args = init_args if init_args is not None else ()
+ self._init_kwargs = init_kwargs if init_kwargs is not None else {}
+ self._state_cache = {}
+ self._current_context = None
+
+ def get_state(self):
+ if self._current_context is None:
+ raise ValueError("No context is set. Please set a context before retrieving the state.")
+ if self._current_context not in self._state_cache.keys():
+ self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
+ return self._state_cache[self._current_context]
+
+ def set_context(self, name: str) -> None:
+ self._current_context = name
+
+ def reset(self, *args, **kwargs) -> None:
+ for name, state in list(self._state_cache.items()):
+ state.reset(*args, **kwargs)
+ self._state_cache.pop(name)
+ self._current_context = None
+
+
class ModelHook:
r"""
A hook that contains callbacks to be executed just before and after the forward method of a model.
@@ -99,6 +132,14 @@ class ModelHook:
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
return module
+ def _set_context(self, module: torch.nn.Module, name: str) -> None:
+ # Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
+ for attr_name in dir(self):
+ attr = getattr(self, attr_name)
+ if isinstance(attr, StateManager):
+ attr.set_context(name)
+ return module
+
class HookFunctionReference:
def __init__(self) -> None:
@@ -211,9 +252,10 @@ class HookRegistry:
hook.reset_state(self._module_ref)
if recurse:
- for module_name, module in self._module_ref.named_modules():
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
+ module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook.reset_stateful_hooks(recurse=False)
@@ -223,6 +265,19 @@ class HookRegistry:
module._diffusers_hook = cls(module)
return module._diffusers_hook
+ def _set_context(self, name: Optional[str] = None) -> None:
+ for hook_name in reversed(self._hook_order):
+ hook = self.hooks[hook_name]
+ if hook._is_stateful:
+ hook._set_context(self._module_ref, name)
+
+ for module_name, module in unwrap_module(self._module_ref).named_modules():
+ if module_name == "":
+ continue
+ module = unwrap_module(module)
+ if hasattr(module, "_diffusers_hook"):
+ module._diffusers_hook._set_context(name)
+
def __repr__(self) -> str:
registry_repr = ""
for i, hook_name in enumerate(self._hook_order):
diff --git a/src/diffusers/hooks/layer_skip.py b/src/diffusers/hooks/layer_skip.py
new file mode 100644
index 0000000000..0ce02e987d
--- /dev/null
+++ b/src/diffusers/hooks/layer_skip.py
@@ -0,0 +1,263 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import Callable, List, Optional
+
+import torch
+
+from ..utils import get_logger
+from ..utils.torch_utils import unwrap_module
+from ._common import (
+ _ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _ATTENTION_CLASSES,
+ _FEEDFORWARD_CLASSES,
+ _get_submodule_from_fqn,
+)
+from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_LAYER_SKIP_HOOK = "layer_skip_hook"
+
+
+# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
+# either remove or make it serializable
+@dataclass
+class LayerSkipConfig:
+ r"""
+ Configuration for skipping internal transformer blocks when executing a transformer model.
+
+ Args:
+ indices (`List[int]`):
+ The indices of the layer to skip. This is typically the first layer in the transformer block.
+ fqn (`str`, defaults to `"auto"`):
+ The fully qualified name identifying the stack of transformer blocks. Typically, this is
+ `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+ For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+ provide the correct fqn.
+ skip_attention (`bool`, defaults to `True`):
+ Whether to skip attention blocks.
+ skip_ff (`bool`, defaults to `True`):
+ Whether to skip feed-forward blocks.
+ skip_attention_scores (`bool`, defaults to `False`):
+ Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
+ projections as the output of scaled dot product attention.
+ dropout (`float`, defaults to `1.0`):
+ The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
+ meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
+ skipped layers are fully retained, which is equivalent to not skipping any layers.
+ """
+
+ indices: List[int]
+ fqn: str = "auto"
+ skip_attention: bool = True
+ skip_attention_scores: bool = False
+ skip_ff: bool = True
+ dropout: float = 1.0
+
+ def __post_init__(self):
+ if not (0 <= self.dropout <= 1):
+ raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
+ if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+
+ def to_dict(self):
+ return asdict(self)
+
+ @staticmethod
+ def from_dict(data: dict) -> "LayerSkipConfig":
+ return LayerSkipConfig(**data)
+
+
+class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
+ def __torch_function__(self, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func is torch.nn.functional.scaled_dot_product_attention:
+ query = kwargs.get("query", None)
+ key = kwargs.get("key", None)
+ value = kwargs.get("value", None)
+ query = query if query is not None else args[0]
+ key = key if key is not None else args[1]
+ value = value if value is not None else args[2]
+ # If the Q sequence length does not match KV sequence length, methods like
+ # Perturbed Attention Guidance cannot be used (because the caller expects
+ # the same sequence length as Q, but if we return V here, it will not match).
+ # When Q.shape[2] != V.shape[2], PAG will essentially not be applied and
+ # the overall effect would that be of normal CFG with a scale of (guidance_scale + perturbed_guidance_scale).
+ if query.shape[2] == value.shape[2]:
+ return value
+ return func(*args, **kwargs)
+
+
+class AttentionProcessorSkipHook(ModelHook):
+ def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
+ self.skip_processor_output_fn = skip_processor_output_fn
+ self.skip_attention_scores = skip_attention_scores
+ self.dropout = dropout
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if self.skip_attention_scores:
+ if not math.isclose(self.dropout, 1.0):
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+ with AttentionScoreSkipFunctionMode():
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ else:
+ if math.isclose(self.dropout, 1.0):
+ output = self.skip_processor_output_fn(module, *args, **kwargs)
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+class FeedForwardSkipHook(ModelHook):
+ def __init__(self, dropout: float):
+ super().__init__()
+ self.dropout = dropout
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if math.isclose(self.dropout, 1.0):
+ output = kwargs.get("hidden_states", None)
+ if output is None:
+ output = kwargs.get("x", None)
+ if output is None and len(args) > 0:
+ output = args[0]
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+class TransformerBlockSkipHook(ModelHook):
+ def __init__(self, dropout: float):
+ super().__init__()
+ self.dropout = dropout
+
+ def initialize_hook(self, module):
+ self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
+ return module
+
+ def new_forward(self, module: torch.nn.Module, *args, **kwargs):
+ if math.isclose(self.dropout, 1.0):
+ original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
+ if self._metadata.return_encoder_hidden_states_index is None:
+ output = original_hidden_states
+ else:
+ original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
+ "encoder_hidden_states", args, kwargs
+ )
+ output = (original_hidden_states, original_encoder_hidden_states)
+ else:
+ output = self.fn_ref.original_forward(*args, **kwargs)
+ output = torch.nn.functional.dropout(output, p=self.dropout)
+ return output
+
+
+def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
+ r"""
+ Apply layer skipping to internal layers of a transformer.
+
+ Args:
+ module (`torch.nn.Module`):
+ The transformer model to which the layer skip hook should be applied.
+ config (`LayerSkipConfig`):
+ The configuration for the layer skip hook.
+
+ Example:
+
+ ```python
+ >>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
+
+ >>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+ >>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
+ >>> apply_layer_skip_hook(transformer, config)
+ ```
+ """
+ _apply_layer_skip_hook(module, config)
+
+
+def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
+ name = name or _LAYER_SKIP_HOOK
+
+ if config.skip_attention and config.skip_attention_scores:
+ raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
+ if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
+ raise ValueError(
+ "Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
+ )
+
+ if config.fqn == "auto":
+ for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+ if hasattr(module, identifier):
+ config.fqn = identifier
+ break
+ else:
+ raise ValueError(
+ "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+ "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+ )
+
+ transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+ if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
+ raise ValueError(
+ f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
+ f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
+ )
+ if len(config.indices) == 0:
+ raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
+
+ blocks_found = False
+ for i, block in enumerate(transformer_blocks):
+ if i not in config.indices:
+ continue
+
+ blocks_found = True
+
+ if config.skip_attention and config.skip_ff:
+ logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
+ registry = HookRegistry.check_if_exists_or_initialize(block)
+ hook = TransformerBlockSkipHook(config.dropout)
+ registry.register_hook(hook, name)
+
+ elif config.skip_attention or config.skip_attention_scores:
+ for submodule_name, submodule in block.named_modules():
+ if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
+ logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+ output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
+ registry.register_hook(hook, name)
+
+ if config.skip_ff:
+ for submodule_name, submodule in block.named_modules():
+ if isinstance(submodule, _FEEDFORWARD_CLASSES):
+ logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
+ registry = HookRegistry.check_if_exists_or_initialize(submodule)
+ hook = FeedForwardSkipHook(config.dropout)
+ registry.register_hook(hook, name)
+
+ if not blocks_found:
+ raise ValueError(
+ f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+ f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+ )
diff --git a/src/diffusers/hooks/layerwise_casting.py b/src/diffusers/hooks/layerwise_casting.py
index 1747a5c489..a036ad37dc 100644
--- a/src/diffusers/hooks/layerwise_casting.py
+++ b/src/diffusers/hooks/layerwise_casting.py
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Type, Union
import torch
from ..utils import get_logger, is_peft_available, is_peft_version
+from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -27,12 +28,6 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
# fmt: off
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
-SUPPORTED_PYTORCH_LAYERS = (
- torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
- torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
- torch.nn.Linear,
-)
-
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on
@@ -186,7 +181,7 @@ def _apply_layerwise_casting(
logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
return
- if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ if isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
return
diff --git a/src/diffusers/hooks/pyramid_attention_broadcast.py b/src/diffusers/hooks/pyramid_attention_broadcast.py
index bbdd1c3f68..ee3f410331 100644
--- a/src/diffusers/hooks/pyramid_attention_broadcast.py
+++ b/src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -18,8 +18,15 @@ from typing import Any, Callable, Optional, Tuple, Union
import torch
+from ..models.attention import AttentionModuleMixin
from ..models.attention_processor import Attention, MochiAttention
from ..utils import logging
+from ._common import (
+ _ATTENTION_CLASSES,
+ _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+ _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
+)
from .hooks import HookRegistry, ModelHook
@@ -27,10 +34,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
_PYRAMID_ATTENTION_BROADCAST_HOOK = "pyramid_attention_broadcast"
-_ATTENTION_CLASSES = (Attention, MochiAttention)
-_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks")
-_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
-_CROSS_ATTENTION_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks")
@dataclass
@@ -60,11 +63,11 @@ class PyramidAttentionBroadcastConfig:
cross_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(100, 800)`):
The range of timesteps to skip in the cross-attention layer. The attention computations will be
conditionally skipped if the current timestep is within the specified range.
- spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ spatial_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a spatial attention layer.
- temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("temporal_transformer_blocks",)`):
+ temporal_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a temporal attention layer.
- cross_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("blocks", "transformer_blocks")`):
+ cross_attention_block_identifiers (`Tuple[str, ...]`):
The identifiers to match against the layer names to determine if the layer is a cross-attention layer.
"""
@@ -76,9 +79,9 @@ class PyramidAttentionBroadcastConfig:
temporal_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
cross_attention_timestep_skip_range: Tuple[int, int] = (100, 800)
- spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS
- temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
- cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_ATTENTION_BLOCK_IDENTIFIERS
+ spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS
+ temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS
+ cross_attention_block_identifiers: Tuple[str, ...] = _CROSS_TRANSFORMER_BLOCK_IDENTIFIERS
current_timestep_callback: Callable[[], int] = None
@@ -227,7 +230,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
config.spatial_attention_block_skip_range = 2
for name, submodule in module.named_modules():
- if not isinstance(submodule, _ATTENTION_CLASSES):
+ if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
# PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
# cannot be applied to this layer. For custom layers, users can extend this functionality and implement
# their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
diff --git a/src/diffusers/hooks/smoothed_energy_guidance_utils.py b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
new file mode 100644
index 0000000000..622f607647
--- /dev/null
+++ b/src/diffusers/hooks/smoothed_energy_guidance_utils.py
@@ -0,0 +1,167 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from dataclasses import asdict, dataclass
+from typing import List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from ..utils import get_logger
+from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
+from .hooks import HookRegistry, ModelHook
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
+
+
+@dataclass
+class SmoothedEnergyGuidanceConfig:
+ r"""
+ Configuration for skipping internal transformer blocks when executing a transformer model.
+
+ Args:
+ indices (`List[int]`):
+ The indices of the layer to skip. This is typically the first layer in the transformer block.
+ fqn (`str`, defaults to `"auto"`):
+ The fully qualified name identifying the stack of transformer blocks. Typically, this is
+ `transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
+ For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
+ provide the correct fqn.
+ _query_proj_identifiers (`List[str]`, defaults to `None`):
+ The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
+ `None`, `to_q` is used by default.
+ """
+
+ indices: List[int]
+ fqn: str = "auto"
+ _query_proj_identifiers: List[str] = None
+
+ def to_dict(self):
+ return asdict(self)
+
+ @staticmethod
+ def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
+ return SmoothedEnergyGuidanceConfig(**data)
+
+
+class SmoothedEnergyGuidanceHook(ModelHook):
+ def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
+ super().__init__()
+ self.blur_sigma = blur_sigma
+ self.blur_threshold_inf = blur_threshold_inf
+
+ def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
+ # Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
+ kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
+ smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
+ return smoothed_output
+
+
+def _apply_smoothed_energy_guidance_hook(
+ module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
+) -> None:
+ name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
+
+ if config.fqn == "auto":
+ for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
+ if hasattr(module, identifier):
+ config.fqn = identifier
+ break
+ else:
+ raise ValueError(
+ "Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
+ "`fqn` (fully qualified name) that identifies a stack of transformer blocks."
+ )
+
+ if config._query_proj_identifiers is None:
+ config._query_proj_identifiers = ["to_q"]
+
+ transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
+ blocks_found = False
+ for i, block in enumerate(transformer_blocks):
+ if i not in config.indices:
+ continue
+
+ blocks_found = True
+
+ for submodule_name, submodule in block.named_modules():
+ if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
+ continue
+ for identifier in config._query_proj_identifiers:
+ query_proj = getattr(submodule, identifier, None)
+ if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
+ continue
+ logger.debug(
+ f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
+ )
+ registry = HookRegistry.check_if_exists_or_initialize(query_proj)
+ hook = SmoothedEnergyGuidanceHook(blur_sigma)
+ registry.register_hook(hook, name)
+
+ if not blocks_found:
+ raise ValueError(
+ f"Could not find any transformer blocks matching the provided indices {config.indices} and "
+ f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
+ )
+
+
+# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
+def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
+ """
+ This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
+ However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
+ implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
+ assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
+ Smoothed Energy Guidance.
+
+ SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
+ future without warning or guarantee of reproducibility.
+ """
+ assert query.ndim == 3
+
+ is_inf = sigma > sigma_threshold_inf
+ batch_size, seq_len, embed_dim = query.shape
+
+ seq_len_sqrt = int(math.sqrt(seq_len))
+ num_square_tokens = seq_len_sqrt * seq_len_sqrt
+ query_slice = query[:, :num_square_tokens, :]
+ query_slice = query_slice.permute(0, 2, 1)
+ query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
+
+ if is_inf:
+ kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
+ kernel_size_half = (kernel_size - 1) / 2
+
+ x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
+ pdf = torch.exp(-0.5 * (x / sigma).pow(2))
+ kernel1d = pdf / pdf.sum()
+ kernel1d = kernel1d.to(query)
+ kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
+ kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
+
+ padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
+ query_slice = F.pad(query_slice, padding, mode="reflect")
+ query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
+ else:
+ query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
+
+ query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
+ query_slice = query_slice.permute(0, 2, 1)
+ query[:, :num_square_tokens, :] = query_slice.clone()
+
+ return query
diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py
index 84c6d9f32c..9f46b5acd3 100644
--- a/src/diffusers/loaders/__init__.py
+++ b/src/diffusers/loaders/__init__.py
@@ -78,12 +78,14 @@ if is_torch_available():
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
+ "SkyReelsV2LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
+ "ModularIPAdapterMixin",
]
_import_structure["peft"] = ["PeftAdapterMixin"]
@@ -101,6 +103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
+ ModularIPAdapterMixin,
SD3IPAdapterMixin,
)
from .lora_pipeline import (
@@ -117,6 +120,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Mochi1LoraLoaderMixin,
SanaLoraLoaderMixin,
SD3LoraLoaderMixin,
+ SkyReelsV2LoraLoaderMixin,
StableDiffusionLoraLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
WanLoraLoaderMixin,
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 521cb3b6fd..dca4758ba0 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@ if is_transformers_available():
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
- FluxAttnProcessor2_0,
- FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -354,6 +352,256 @@ class IPAdapterMixin:
self.unet.set_attn_processor(attn_procs)
+class ModularIPAdapterMixin:
+ """Mixin for handling IP Adapters."""
+
+ @validate_hf_hub_args
+ def load_ip_adapter(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+ subfolder: Union[str, List[str]],
+ weight_name: Union[str, List[str]],
+ **kwargs,
+ ):
+ """
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+ subfolder (`str` or `List[str]`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+ list is passed, it should have the same length as `weight_name`.
+ weight_name (`str` or `List[str]`):
+ The name of the weight file to load. If a list is passed, it should have the same length as
+ `subfolder`.
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+ Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+ tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+ Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+ argument to `True` will raise an error.
+ """
+
+ # handle the list inputs for multiple IP Adapters
+ if not isinstance(weight_name, list):
+ weight_name = [weight_name]
+
+ if not isinstance(pretrained_model_name_or_path_or_dict, list):
+ pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
+ if len(pretrained_model_name_or_path_or_dict) == 1:
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
+
+ if not isinstance(subfolder, list):
+ subfolder = [subfolder]
+ if len(subfolder) == 1:
+ subfolder = subfolder * len(weight_name)
+
+ if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
+ raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
+
+ if len(weight_name) != len(subfolder):
+ raise ValueError("`weight_name` and `subfolder` must have the same length.")
+
+ # Load the main state dict first.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+
+ if low_cpu_mem_usage and not is_accelerate_available():
+ low_cpu_mem_usage = False
+ logger.warning(
+ "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+ " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+ " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+ " install accelerate\n```\n."
+ )
+
+ if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+ raise NotImplementedError(
+ "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+ " `low_cpu_mem_usage=False`."
+ )
+
+ user_agent = {
+ "file_type": "attn_procs_weights",
+ "framework": "pytorch",
+ }
+ state_dicts = []
+ for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
+ pretrained_model_name_or_path_or_dict, weight_name, subfolder
+ ):
+ if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+ model_file = _get_model_file(
+ pretrained_model_name_or_path_or_dict,
+ weights_name=weight_name,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ local_files_only=local_files_only,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ )
+ if weight_name.endswith(".safetensors"):
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
+ with safe_open(model_file, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ if key.startswith("image_proj."):
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
+ elif key.startswith("ip_adapter."):
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
+ else:
+ state_dict = load_state_dict(model_file)
+ else:
+ state_dict = pretrained_model_name_or_path_or_dict
+
+ keys = list(state_dict.keys())
+ if "image_proj" not in keys and "ip_adapter" not in keys:
+ raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
+
+ state_dicts.append(state_dict)
+
+ unet_name = getattr(self, "unet_name", "unet")
+ unet = getattr(self, unet_name)
+ unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ extra_loras = unet._load_ip_adapter_loras(state_dicts)
+ if extra_loras != {}:
+ if not USE_PEFT_BACKEND:
+ logger.warning("PEFT backend is required to load these weights.")
+ else:
+ # apply the IP Adapter Face ID LoRA weights
+ peft_config = getattr(unet, "peft_config", {})
+ for k, lora in extra_loras.items():
+ if f"faceid_{k}" not in peft_config:
+ self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+ self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+
+ def set_ip_adapter_scale(self, scale):
+ """
+ Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+ granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
+
+ Example:
+
+ ```py
+ # To use original IP-Adapter
+ scale = 1.0
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style block only
+ scale = {
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style+layout blocks
+ scale = {
+ "down": {"block_2": [0.0, 1.0]},
+ "up": {"block_0": [0.0, 1.0, 0.0]},
+ }
+ pipeline.set_ip_adapter_scale(scale)
+
+ # To use style and layout from 2 reference images
+ scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+ pipeline.set_ip_adapter_scale(scales)
+ ```
+ """
+ unet_name = getattr(self, "unet_name", "unet")
+ unet = getattr(self, unet_name)
+ if not isinstance(scale, list):
+ scale = [scale]
+ scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+ for attn_name, attn_processor in unet.attn_processors.items():
+ if isinstance(
+ attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ ):
+ if len(scale_configs) != len(attn_processor.scale):
+ raise ValueError(
+ f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
+ )
+ elif len(scale_configs) == 1:
+ scale_configs = scale_configs * len(attn_processor.scale)
+ for i, scale_config in enumerate(scale_configs):
+ if isinstance(scale_config, dict):
+ for k, s in scale_config.items():
+ if attn_name.startswith(k):
+ attn_processor.scale[i] = s
+ else:
+ attn_processor.scale[i] = scale_config
+
+ def unload_ip_adapter(self):
+ """
+ Unloads the IP Adapter weights
+
+ Examples:
+
+ ```python
+ >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
+ >>> pipeline.unload_ip_adapter()
+ >>> ...
+ ```
+ """
+
+ # remove hidden encoder
+ if self.unet is None:
+ return
+
+ self.unet.encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = None
+
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+ self.unet.text_encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = "text_proj"
+
+ # restore original Unet attention processors layers
+ attn_procs = {}
+ for name, value in self.unet.attn_processors.items():
+ attn_processor_class = (
+ AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
+ )
+ attn_procs[name] = (
+ attn_processor_class
+ if isinstance(
+ value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
+ )
+ else value.__class__()
+ )
+ self.unet.set_attn_processor(attn_procs)
+
+
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@@ -617,6 +865,9 @@ class FluxIPAdapterMixin:
>>> ...
```
"""
+ # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+ from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
@@ -636,9 +887,9 @@ class FluxIPAdapterMixin:
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
- attn_processor_class = FluxAttnProcessor2_0()
+ attn_processor_class = FluxAttnProcessor()
attn_procs[name] = (
- attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+ attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
diff --git a/src/diffusers/loaders/lora_base.py b/src/diffusers/loaders/lora_base.py
index 16f0d48365..3089086d54 100644
--- a/src/diffusers/loaders/lora_base.py
+++ b/src/diffusers/loaders/lora_base.py
@@ -330,6 +330,8 @@ def _load_lora_into_text_encoder(
hotswap: bool = False,
metadata=None,
):
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
@@ -391,7 +393,9 @@ def _load_lora_into_text_encoder(
adapter_name = get_adapter_name(text_encoder)
#
if prefix is not None and not state_dict:
@@ -433,30 +441,38 @@ def _func_optionally_disable_offloading(_pipeline):
Returns:
tuple:
- A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+ A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
"""
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
- if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
- if not is_model_cpu_offload:
- is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
- if not is_sequential_cpu_offload:
- is_sequential_cpu_offload = (
- isinstance(component._hf_hook, AlignDevicesHook)
- or hasattr(component._hf_hook, "hooks")
- and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
- )
+ if not isinstance(component, nn.Module):
+ continue
+ is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+ if not hasattr(component, "_hf_hook"):
+ continue
+ is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+ is_sequential_cpu_offload = is_sequential_cpu_offload or (
+ isinstance(component._hf_hook, AlignDevicesHook)
+ or hasattr(component._hf_hook, "hooks")
+ and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+ )
- logger.info(
- "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
- )
- if is_sequential_cpu_offload or is_model_cpu_offload:
- remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+ if is_sequential_cpu_offload or is_model_cpu_offload:
+ logger.info(
+ "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+ )
+ for _, component in _pipeline.components.items():
+ if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+ continue
+ remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
- return (is_model_cpu_offload, is_sequential_cpu_offload)
+ return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
class LoraBaseMixin:
@@ -921,6 +937,27 @@ class LoraBaseMixin:
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
+ After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
+ can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
+ GPU before using those LoRA adapters for inference.
+
+ ```python
+ >>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
+ >>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
+ >>> pipe.set_adapters("adapter-1")
+ >>> image_1 = pipe(**kwargs)
+ >>> # switch to adapter-2, offload adapter-1
+ >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
+ >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-2")
+ >>> image_2 = pipe(**kwargs)
+ >>> # switch back to adapter-1, offload adapter-2
+ >>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
+ >>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
+ >>> pipe.set_adapters("adapter-1")
+ >>> ...
+ ```
+
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
@@ -936,6 +973,10 @@ class LoraBaseMixin:
for module in model.modules():
if isinstance(module, BaseTunerLayer):
for adapter_name in adapter_names:
+ if adapter_name not in module.lora_A:
+ # it is sufficient to check lora_A
+ continue
+
module.lora_A[adapter_name].to(device)
module.lora_B[adapter_name].to(device)
# this is a param, not a module, so device placement is not in-place -> re-assign
@@ -1022,15 +1063,3 @@ class LoraBaseMixin:
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
return _func_optionally_disable_offloading(_pipeline=_pipeline)
-
- @classmethod
- def _fetch_state_dict(cls, *args, **kwargs):
- deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
- deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
- return _fetch_state_dict(*args, **kwargs)
-
- @classmethod
- def _best_guess_weight_name(cls, *args, **kwargs):
- deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
- deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
- return _best_guess_weight_name(*args, **kwargs)
diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py
index 25e06c007f..df3aa6212f 100644
--- a/src/diffusers/loaders/lora_conversion_utils.py
+++ b/src/diffusers/loaders/lora_conversion_utils.py
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
return converted_state_dict
+def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
+ converted_state_dict = {}
+ original_state_dict_keys = list(original_state_dict.keys())
+ num_layers = 19
+ num_single_layers = 38
+ inner_dim = 3072
+ mlp_ratio = 4.0
+
+ # double transformer blocks
+ for i in range(num_layers):
+ block_prefix = f"transformer_blocks.{i}."
+ original_block_prefix = "base_model.model."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norms
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
+ )
+ if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
+ )
+
+ # Q, K, V
+ if lora_key == "lora_A":
+ sample_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
+
+ context_lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
+ [context_lora_weight]
+ )
+ else:
+ sample_q, sample_k, sample_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
+
+ context_q, context_k, context_v = torch.chunk(
+ original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
+ ),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
+
+ if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
+
+ if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
+ context_q_bias, context_k_bias, context_v_bias = torch.chunk(
+ original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
+ 3,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
+ converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
+ converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
+
+ # ff img_mlp
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
+ )
+
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
+ )
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
+ )
+
+ # single transformer blocks
+ for i in range(num_single_layers):
+ block_prefix = f"single_transformer_blocks.{i}."
+
+ for lora_key in ["lora_A", "lora_B"]:
+ # norm.linear <- single_blocks.0.modulation.lin
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
+ )
+
+ # Q, K, V, mlp
+ mlp_hidden_dim = int(inner_dim * mlp_ratio)
+ split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
+
+ if lora_key == "lora_A":
+ lora_weight = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
+ else:
+ q, k, v, mlp = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
+
+ if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
+ q_bias, k_bias, v_bias, mlp_bias = torch.split(
+ original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
+ split_size,
+ dim=0,
+ )
+ converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
+ converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
+ converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
+ converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
+
+ # output projections.
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
+ )
+
+ for lora_key in ["lora_A", "lora_B"]:
+ converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
+ )
+ if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
+ converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
+ f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
+ )
+
+ if len(original_state_dict) > 0:
+ raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+ for key in list(converted_state_dict.keys()):
+ converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+ return converted_state_dict
+
+
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
@@ -1603,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
+ has_time_projection_weight = any(
+ k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
+ )
- diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
- if diff_keys:
- for diff_k in diff_keys:
- param = original_state_dict[diff_k]
- # The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
- # and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
- # to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
- # is okay to ignore because they do not affect the model output in a significant manner.
- threshold = 1.6e-2
- absdiff = param.abs().max() - param.abs().min()
- all_zero = torch.all(param == 0).item()
- all_absdiff_lower_than_threshold = absdiff < threshold
- if all_zero or all_absdiff_lower_than_threshold:
- logger.debug(
- f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
- )
- original_state_dict.pop(diff_k)
+ for key in list(original_state_dict.keys()):
+ if key.endswith((".diff", ".diff_b")) and "norm" in key:
+ # NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
+ # in future if needed and they are not zeroed.
+ original_state_dict.pop(key)
+ logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
+
+ if "time_projection" in key and not has_time_projection_weight:
+ # AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
+ # our lora config adds the time proj lora layers, but we don't have the weights for them.
+ # CausVid lora has the weight keys and the bias keys.
+ original_state_dict.pop(key)
# For the `diff_b` keys, we treat them as lora_bias.
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py
index 4fea005cbc..7fd13176ac 100644
--- a/src/diffusers/loaders/lora_pipeline.py
+++ b/src/diffusers/loaders/lora_pipeline.py
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
+ _convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
return_metadata=return_lora_metadata,
)
+ is_fal_kontext = any("base_model" in k for k in state_dict)
+ if is_fal_kontext:
+ state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+ return cls._prepare_outputs(
+ state_dict,
+ metadata=metadata,
+ alphas=None,
+ return_alphas=return_alphas,
+ return_metadata=return_lora_metadata,
+ )
+
# For state dicts like
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
keys = list(state_dict.keys())
@@ -5442,6 +5454,404 @@ class WanLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
+class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
+ r"""
+ Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
+ """
+
+ _lora_loadable_modules = ["transformer"]
+ transformer_name = TRANSFORMER_NAME
+
+ @classmethod
+ @validate_hf_hub_args
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
+ def lora_state_dict(
+ cls,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ **kwargs,
+ ):
+ r"""
+ Return state dict for lora weights and the network alphas.
+
+
+
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+ This function is experimental and might change in the future.
+
+
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ Can be either:
+
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+ the Hub.
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+ with [`ModelMixin.save_pretrained`].
+ - A [torch state
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+ is not used.
+ force_download (`bool`, *optional*, defaults to `False`):
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+ cached versions if they exist.
+
+ proxies (`Dict[str, str]`, *optional*):
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+ local_files_only (`bool`, *optional*, defaults to `False`):
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
+ won't be downloaded from the Hub.
+ token (`str` or *bool*, *optional*):
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
+ revision (`str`, *optional*, defaults to `"main"`):
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+ allowed by Git.
+ subfolder (`str`, *optional*, defaults to `""`):
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
+ return_lora_metadata (`bool`, *optional*, defaults to False):
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+ """
+ # Load the main state dict first which has the LoRA layers for either of
+ # transformer and text encoder or both.
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", None)
+ token = kwargs.pop("token", None)
+ revision = kwargs.pop("revision", None)
+ subfolder = kwargs.pop("subfolder", None)
+ weight_name = kwargs.pop("weight_name", None)
+ use_safetensors = kwargs.pop("use_safetensors", None)
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+ allow_pickle = False
+ if use_safetensors is None:
+ use_safetensors = True
+ allow_pickle = True
+
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+ state_dict, metadata = _fetch_state_dict(
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+ weight_name=weight_name,
+ use_safetensors=use_safetensors,
+ local_files_only=local_files_only,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ token=token,
+ revision=revision,
+ subfolder=subfolder,
+ user_agent=user_agent,
+ allow_pickle=allow_pickle,
+ )
+ if any(k.startswith("diffusion_model.") for k in state_dict):
+ state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+ elif any(k.startswith("lora_unet_") for k in state_dict):
+ state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
+
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+ if is_dora_scale_present:
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+ logger.warning(warn_msg)
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
+ return out
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
+ def _maybe_expand_t2v_lora_for_i2v(
+ cls,
+ transformer: torch.nn.Module,
+ state_dict,
+ ):
+ if transformer.config.image_dim is None:
+ return state_dict
+
+ target_device = transformer.device
+
+ if any(k.startswith("transformer.blocks.") for k in state_dict):
+ num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+ is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+ has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+ if is_i2v_lora:
+ return state_dict
+
+ for i in range(num_blocks):
+ for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+ # These keys should exist if the block `i` was part of the T2V LoRA.
+ ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+ ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+ if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+ continue
+
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+ )
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+ state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+ )
+
+ # If the original LoRA had biases (indicated by has_bias)
+ # AND the specific reference bias key exists for this block.
+
+ ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+ if has_bias and ref_key_lora_B_bias in state_dict:
+ ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+ state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+ ref_lora_B_bias_tensor,
+ device=target_device,
+ )
+
+ return state_dict
+
+ # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
+ def load_lora_weights(
+ self,
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+ adapter_name: Optional[str] = None,
+ hotswap: bool = False,
+ **kwargs,
+ ):
+ """
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+ dict is loaded into `self.transformer`.
+
+ Parameters:
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ kwargs (`dict`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+ """
+ if not USE_PEFT_BACKEND:
+ raise ValueError("PEFT backend is required for this method.")
+
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # if a dict is passed, copy it instead of modifying it inplace
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+ kwargs["return_lora_metadata"] = True
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+ # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+ state_dict = self._maybe_expand_t2v_lora_for_i2v(
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ state_dict=state_dict,
+ )
+ is_correct_format = all("lora" in key for key in state_dict.keys())
+ if not is_correct_format:
+ raise ValueError("Invalid LoRA checkpoint.")
+
+ self.load_lora_into_transformer(
+ state_dict,
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=self,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
+ def load_lora_into_transformer(
+ cls,
+ state_dict,
+ transformer,
+ adapter_name=None,
+ _pipeline=None,
+ low_cpu_mem_usage=False,
+ hotswap: bool = False,
+ metadata=None,
+ ):
+ """
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+ Parameters:
+ state_dict (`dict`):
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+ encoder lora layers.
+ transformer (`SkyReelsV2Transformer3DModel`):
+ The Transformer model to load the LoRA layers into.
+ adapter_name (`str`, *optional*):
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+ `default_{i}` where i is the total number of adapters being loaded.
+ low_cpu_mem_usage (`bool`, *optional*):
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+ weights.
+ hotswap (`bool`, *optional*):
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+ metadata (`dict`):
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+ from the state dict.
+ """
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+ raise ValueError(
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+ )
+
+ # Load the layers corresponding to transformer.
+ logger.info(f"Loading {cls.transformer_name}.")
+ transformer.load_lora_adapter(
+ state_dict,
+ network_alphas=None,
+ adapter_name=adapter_name,
+ metadata=metadata,
+ _pipeline=_pipeline,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ hotswap=hotswap,
+ )
+
+ @classmethod
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+ def save_lora_weights(
+ cls,
+ save_directory: Union[str, os.PathLike],
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ transformer_lora_adapter_metadata: Optional[dict] = None,
+ ):
+ r"""
+ Save the LoRA parameters corresponding to the transformer.
+
+ Arguments:
+ save_directory (`str` or `os.PathLike`):
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+ State dict of the LoRA layers corresponding to the `transformer`.
+ is_main_process (`bool`, *optional*, defaults to `True`):
+ Whether the process calling this is the main process or not. Useful during distributed training and you
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+ process to avoid race conditions.
+ save_function (`Callable`):
+ The function to use to save the state dictionary. Useful during distributed training when you need to
+ replace `torch.save` with another method. Can be configured with the environment variable
+ `DIFFUSERS_SAVE_MODE`.
+ safe_serialization (`bool`, *optional*, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+ transformer_lora_adapter_metadata:
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+ """
+ state_dict = {}
+ lora_adapter_metadata = {}
+
+ if not transformer_lora_layers:
+ raise ValueError("You must pass `transformer_lora_layers`.")
+
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+ if transformer_lora_adapter_metadata is not None:
+ lora_adapter_metadata.update(
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+ )
+
+ # Save the model
+ cls.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ lora_adapter_metadata=lora_adapter_metadata,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
+ def fuse_lora(
+ self,
+ components: List[str] = ["transformer"],
+ lora_scale: float = 1.0,
+ safe_fusing: bool = False,
+ adapter_names: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ r"""
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+ lora_scale (`float`, defaults to 1.0):
+ Controls how much to influence the outputs with the LoRA parameters.
+ safe_fusing (`bool`, defaults to `False`):
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+ adapter_names (`List[str]`, *optional*):
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+ Example:
+
+ ```py
+ from diffusers import DiffusionPipeline
+ import torch
+
+ pipeline = DiffusionPipeline.from_pretrained(
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+ ).to("cuda")
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+ pipeline.fuse_lora(lora_scale=0.7)
+ ```
+ """
+ super().fuse_lora(
+ components=components,
+ lora_scale=lora_scale,
+ safe_fusing=safe_fusing,
+ adapter_names=adapter_names,
+ **kwargs,
+ )
+
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+ r"""
+ Reverses the effect of
+ [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+
+
+ This is an experimental API.
+
+
+
+ Args:
+ components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+ unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+ """
+ super().unfuse_lora(components=components, **kwargs)
+
+
class CogView4LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
diff --git a/src/diffusers/loaders/peft.py b/src/diffusers/loaders/peft.py
index 3436230713..393c8ee27d 100644
--- a/src/diffusers/loaders/peft.py
+++ b/src/diffusers/loaders/peft.py
@@ -163,6 +163,8 @@ class PeftAdapterMixin:
from peft import inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -243,20 +245,29 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
- # create LoraConfig
- lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
-
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
+ # create LoraConfig
+ lora_config = _create_lora_config(
+ state_dict,
+ network_alphas,
+ metadata,
+ rank,
+ model_state_dict=self.state_dict(),
+ adapter_name=adapter_name,
+ )
+
# =", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -347,6 +358,10 @@ class PeftAdapterMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
+ elif is_group_offload:
+ for component in _pipeline.components.values():
+ if isinstance(component, torch.nn.Module):
+ _maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
if prefix is not None and not state_dict:
@@ -681,11 +696,16 @@ class PeftAdapterMixin:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
+ if hasattr(self, "_hf_peft_config_loaded"):
+ self._hf_peft_config_loaded = None
+
+ _maybe_remove_and_reapply_group_offloading(self)
def disable_lora(self):
"""
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index 0c6f3cda66..76fefc1260 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -24,6 +24,7 @@ from typing_extensions import Self
from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
+from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
convert_animatediff_checkpoint_to_diffusers,
@@ -31,6 +32,7 @@ from .single_file_utils import (
convert_autoencoder_dc_checkpoint_to_diffusers,
convert_chroma_transformer_checkpoint_to_diffusers,
convert_controlnet_checkpoint,
+ convert_cosmos_transformer_checkpoint_to_diffusers,
convert_flux_transformer_checkpoint_to_diffusers,
convert_hidream_transformer_to_diffusers,
convert_hunyuan_video_transformer_to_diffusers,
@@ -135,6 +137,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "WanVACETransformer3DModel": {
+ "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+ "default_subfolder": "transformer",
+ },
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
@@ -143,6 +149,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
"default_subfolder": "transformer",
},
+ "CosmosTransformer3DModel": {
+ "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
+ "default_subfolder": "transformer",
+ },
}
@@ -421,6 +431,7 @@ class FromOriginalModelMixin:
keep_in_fp32_modules=keep_in_fp32_modules,
unexpected_keys=unexpected_keys,
)
+ empty_device_cache()
else:
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index d8d183304e..a804ea80a9 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -46,6 +46,7 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
from ..utils.hub_utils import _get_model_file
+from ..utils.torch_utils import empty_device_cache
if is_transformers_available():
@@ -126,7 +127,18 @@ CHECKPOINT_KEY_NAMES = {
],
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
+ "wan_vace": "vace_blocks.0.after_proj.bias",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
+ "cosmos-1.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
+ "net.extra_pos_embedder.pos_emb_h",
+ ],
+ "cosmos-2.0": [
+ "net.x_embedder.proj.1.weight",
+ "net.blocks.0.self_attn.q_proj.weight",
+ "net.pos_embedder.dim_spatial_range",
+ ],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -192,7 +204,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+ "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
+ "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
+ "cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
+ "cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
+ "cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
+ "cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
+ "cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
+ "cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
+ "cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
+ "cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
}
# Use to configure model sample size when original config is provided
@@ -698,17 +720,44 @@ def infer_diffusers_model_type(checkpoint):
else:
target_key = "patch_embedding.weight"
- if checkpoint[target_key].shape[0] == 1536:
+ if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
+ if checkpoint[target_key].shape[0] == 1536:
+ model_type = "wan-vace-1.3B"
+ elif checkpoint[target_key].shape[0] == 5120:
+ model_type = "wan-vace-14B"
+
+ elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
model_type = "wan-t2v-14B"
else:
model_type = "wan-i2v-14B"
+
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
model_type = "wan-t2v-14B"
+
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
model_type = "hidream"
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
+
+ elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
+ x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
+ if x_embedder_shape[1] == 68:
+ model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
+ elif x_embedder_shape[1] == 72:
+ model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
+ else:
+ raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
+
else:
model_type = "v1"
@@ -1641,6 +1690,7 @@ def create_diffusers_clip_model_from_ldm(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -2100,6 +2150,7 @@ def create_diffusers_t5_model_from_checkpoint(
if is_accelerate_available():
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+ empty_device_cache()
else:
model.load_state_dict(diffusers_format_checkpoint)
@@ -3093,6 +3144,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+ # For the VACE model
+ "before_proj": "proj_in",
+ "after_proj": "proj_out",
}
for key in list(checkpoint.keys()):
@@ -3479,3 +3533,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
return converted_state_dict
+
+
+def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
+ converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+
+ def remove_keys_(key: str, state_dict):
+ state_dict.pop(key)
+
+ def rename_transformer_blocks_(key: str, state_dict):
+ block_index = int(key.split(".")[1].removeprefix("block"))
+ new_key = key
+ old_prefix = f"blocks.block{block_index}"
+ new_prefix = f"transformer_blocks.{block_index}"
+ new_key = new_prefix + new_key.removeprefix(old_prefix)
+ state_dict[new_key] = state_dict.pop(key)
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "affline_norm": "time_embed.norm",
+ ".blocks.0.block.attn": ".attn1",
+ ".blocks.1.block.attn": ".attn2",
+ ".blocks.2.block": ".ff",
+ ".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
+ ".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
+ ".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
+ ".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
+ ".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
+ ".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
+ "to_q.0": "to_q",
+ "to_q.1": "norm_q",
+ "to_k.0": "to_k",
+ "to_k.1": "norm_k",
+ "to_v.0": "to_v",
+ "layer1": "net.0.proj",
+ "layer2": "net.2",
+ "proj.1": "proj",
+ "x_embedder": "patch_embed",
+ "extra_pos_embedder": "learnable_pos_embed",
+ "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+ "final_layer.adaLN_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
+ "blocks.block": rename_transformer_blocks_,
+ "logvar.0.freqs": remove_keys_,
+ "logvar.0.phases": remove_keys_,
+ "logvar.1.weight": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ }
+
+ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
+ "t_embedder.1": "time_embed.t_embedder",
+ "t_embedding_norm": "time_embed.norm",
+ "blocks": "transformer_blocks",
+ "adaln_modulation_self_attn.1": "norm1.linear_1",
+ "adaln_modulation_self_attn.2": "norm1.linear_2",
+ "adaln_modulation_cross_attn.1": "norm2.linear_1",
+ "adaln_modulation_cross_attn.2": "norm2.linear_2",
+ "adaln_modulation_mlp.1": "norm3.linear_1",
+ "adaln_modulation_mlp.2": "norm3.linear_2",
+ "self_attn": "attn1",
+ "cross_attn": "attn2",
+ "q_proj": "to_q",
+ "k_proj": "to_k",
+ "v_proj": "to_v",
+ "output_proj": "to_out.0",
+ "q_norm": "norm_q",
+ "k_norm": "norm_k",
+ "mlp.layer1": "ff.net.0.proj",
+ "mlp.layer2": "ff.net.2",
+ "x_embedder.proj.1": "patch_embed.proj",
+ "final_layer.adaln_modulation.1": "norm_out.linear_1",
+ "final_layer.adaln_modulation.2": "norm_out.linear_2",
+ "final_layer.linear": "proj_out",
+ }
+
+ TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
+ "accum_video_sample_counter": remove_keys_,
+ "accum_image_sample_counter": remove_keys_,
+ "accum_iteration": remove_keys_,
+ "accum_train_in_hours": remove_keys_,
+ "pos_embedder.seq": remove_keys_,
+ "pos_embedder.dim_spatial_range": remove_keys_,
+ "pos_embedder.dim_temporal_range": remove_keys_,
+ "_extra_state": remove_keys_,
+ }
+
+ PREFIX_KEY = "net."
+ if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
+ else:
+ TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
+ TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ new_key = key[:]
+ if new_key.startswith(PREFIX_KEY):
+ new_key = new_key.removeprefix(PREFIX_KEY)
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+ new_key = new_key.replace(replace_key, rename_key)
+ converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+ state_dict_keys = list(converted_state_dict.keys())
+ for key in state_dict_keys:
+ for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+ if special_key not in key:
+ continue
+ handler_fn_inplace(key, converted_state_dict)
+
+ return converted_state_dict
diff --git a/src/diffusers/loaders/transformer_flux.py b/src/diffusers/loaders/transformer_flux.py
index c7d81a8bae..ced81960fa 100644
--- a/src/diffusers/loaders/transformer_flux.py
+++ b/src/diffusers/loaders/transformer_flux.py
@@ -18,11 +18,8 @@ from ..models.embeddings import (
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
-from ..utils import (
- is_accelerate_available,
- is_torch_version,
- logging,
-)
+from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
if is_accelerate_available():
@@ -84,13 +81,12 @@ class FluxTransformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
- from ..models.attention_processor import (
- FluxIPAdapterJointAttnProcessor2_0,
- )
+ from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
if low_cpu_mem_usage:
if is_accelerate_available():
@@ -122,7 +118,7 @@ class FluxTransformer2DLoadersMixin:
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
- attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+ attn_processor_class = FluxIPAdapterAttnProcessor
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
@@ -158,6 +154,8 @@ class FluxTransformer2DLoadersMixin:
key_id += 1
+ empty_device_cache()
+
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
diff --git a/src/diffusers/loaders/transformer_sd3.py b/src/diffusers/loaders/transformer_sd3.py
index c58d3280cf..1bc3a9c7a8 100644
--- a/src/diffusers/loaders/transformer_sd3.py
+++ b/src/diffusers/loaders/transformer_sd3.py
@@ -18,6 +18,7 @@ from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
+from ..utils.torch_utils import empty_device_cache
logger = logging.get_logger(__name__)
@@ -80,6 +81,8 @@ class SD3Transformer2DLoadersMixin:
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
)
+ empty_device_cache()
+
return attn_procs
def _convert_ip_adapter_image_proj_to_diffusers(
@@ -147,6 +150,7 @@ class SD3Transformer2DLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_proj
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 68be841191..1d698e5a8b 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -43,6 +43,7 @@ from ..utils import (
is_torch_version,
logging,
)
+from ..utils.torch_utils import empty_device_cache
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
@@ -131,6 +132,8 @@ class UNet2DConditionLoadersMixin:
)
```
"""
+ from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
+
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -203,6 +206,7 @@ class UNet2DConditionLoadersMixin:
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
if is_lora:
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
@@ -211,7 +215,7 @@ class UNet2DConditionLoadersMixin:
if is_custom_diffusion:
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
elif is_lora:
- is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
@@ -230,7 +234,9 @@ class UNet2DConditionLoadersMixin:
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
if is_custom_diffusion and _pipeline is not None:
- is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+ _pipeline=_pipeline
+ )
# only custom diffusion needs to set attn processors
self.set_attn_processor(attn_processors)
@@ -241,6 +247,10 @@ class UNet2DConditionLoadersMixin:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
+ elif is_group_offload:
+ for component in _pipeline.components.values():
+ if isinstance(component, torch.nn.Module):
+ _maybe_remove_and_reapply_group_offloading(component)
# Unsafe code />
def _process_custom_diffusion(self, state_dict):
@@ -307,6 +317,7 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False
is_sequential_cpu_offload = False
+ is_group_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
if len(state_dict_to_be_used) > 0:
@@ -356,7 +367,9 @@ class UNet2DConditionLoadersMixin:
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
- is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+ is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
+ _pipeline
+ )
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
@@ -389,7 +402,7 @@ class UNet2DConditionLoadersMixin:
if warn_msg:
logger.warning(warn_msg)
- return is_model_cpu_offload, is_sequential_cpu_offload
+ return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
@@ -741,6 +754,7 @@ class UNet2DConditionLoadersMixin:
else:
device_map = {"": self.device}
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
+ empty_device_cache()
return image_projection
@@ -838,6 +852,8 @@ class UNet2DConditionLoadersMixin:
key_id += 2
+ empty_device_cache()
+
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
diff --git a/src/diffusers/loaders/unet_loader_utils.py b/src/diffusers/loaders/unet_loader_utils.py
index 274665204d..d5b0e83cbd 100644
--- a/src/diffusers/loaders/unet_loader_utils.py
+++ b/src/diffusers/loaders/unet_loader_utils.py
@@ -14,6 +14,8 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union
+from torch import nn
+
from ..utils import logging
@@ -52,7 +54,7 @@ def _maybe_expand_lora_scales(
weight_for_adapter,
blocks_with_transformer,
transformer_per_block,
- unet.state_dict(),
+ model=unet,
default_scale=default_scale,
)
for weight_for_adapter in weight_scales
@@ -65,7 +67,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
scales: Union[float, Dict],
blocks_with_transformer: Dict[str, int],
transformer_per_block: Dict[str, int],
- state_dict: None,
+ model: nn.Module,
default_scale: float = 1.0,
):
"""
@@ -154,6 +156,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
del scales[updown]
+ state_dict = model.state_dict()
for layer in scales.keys():
if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
raise ValueError(
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 73903a6274..cd1df3667a 100755
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -26,6 +26,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
+ _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
_import_structure["auto_model"] = ["AutoModel"]
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"]
@@ -88,6 +89,7 @@ if is_torch_available():
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
+ _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
_import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"]
@@ -111,6 +113,7 @@ if is_flax_available():
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .adapter import MultiAdapter, T2IAdapter
+ from .attention_dispatch import AttentionBackendName, attention_backend
from .auto_model import AutoModel
from .autoencoders import (
AsymmetricAutoencoderKL,
@@ -176,6 +179,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PriorTransformer,
SanaTransformer2DModel,
SD3Transformer2DModel,
+ SkyReelsV2Transformer3DModel,
StableAudioDiTModel,
T5FilmDecoder,
Transformer2DModel,
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index ae51d3ab13..c720b37955 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -11,23 +11,504 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Optional, Tuple
+
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+import torch.nn as nn
import torch.nn.functional as F
-from torch import nn
from ..utils import deprecate, logging
+from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
-from .attention_processor import Attention, JointAttnProcessor2_0
+from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
from .embeddings import SinusoidalPositionalEmbedding
from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+if is_xformers_available():
+ import xformers as xops
+else:
+ xops = None
+
+
logger = logging.get_logger(__name__)
+class AttentionMixin:
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor()
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+ """
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.fuse_projections()
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+ """
+ for module in self.modules():
+ if isinstance(module, AttentionModuleMixin):
+ module.unfuse_projections()
+
+
+class AttentionModuleMixin:
+ _default_processor_cls = None
+ _available_processors = []
+ fused_projections = False
+
+ def set_processor(self, processor: AttentionProcessor) -> None:
+ """
+ Set the attention processor to use.
+
+ Args:
+ processor (`AttnProcessor`):
+ The attention processor to use.
+ """
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
+ # pop `processor` from `self._modules`
+ if (
+ hasattr(self, "processor")
+ and isinstance(self.processor, torch.nn.Module)
+ and not isinstance(processor, torch.nn.Module)
+ ):
+ logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
+ self._modules.pop("processor")
+
+ self.processor = processor
+
+ def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
+ """
+ Get the attention processor in use.
+
+ Args:
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
+ Set to `True` to return the deprecated LoRA attention processor.
+
+ Returns:
+ "AttentionProcessor": The attention processor in use.
+ """
+ if not return_deprecated_lora:
+ return self.processor
+
+ def set_attention_backend(self, backend: str):
+ from .attention_dispatch import AttentionBackendName
+
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+
+ backend = AttentionBackendName(backend.lower())
+ self.processor._attention_backend = backend
+
+ def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
+ """
+ Set whether to use NPU flash attention from `torch_npu` or not.
+
+ Args:
+ use_npu_flash_attention (`bool`): Whether to use NPU flash attention or not.
+ """
+
+ if use_npu_flash_attention:
+ if not is_torch_npu_available():
+ raise ImportError("torch_npu is not available")
+
+ self.set_attention_backend("_native_npu")
+
+ def set_use_xla_flash_attention(
+ self,
+ use_xla_flash_attention: bool,
+ partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+ is_flux=False,
+ ) -> None:
+ """
+ Set whether to use XLA flash attention from `torch_xla` or not.
+
+ Args:
+ use_xla_flash_attention (`bool`):
+ Whether to use pallas flash attention kernel from `torch_xla` or not.
+ partition_spec (`Tuple[]`, *optional*):
+ Specify the partition specification if using SPMD. Otherwise None.
+ is_flux (`bool`, *optional*, defaults to `False`):
+ Whether the model is a Flux model.
+ """
+ if use_xla_flash_attention:
+ if not is_torch_xla_available():
+ raise ImportError("torch_xla is not available")
+
+ self.set_attention_backend("_native_xla")
+
+ def set_use_memory_efficient_attention_xformers(
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+ ) -> None:
+ """
+ Set whether to use memory efficient attention from `xformers` or not.
+
+ Args:
+ use_memory_efficient_attention_xformers (`bool`):
+ Whether to use memory efficient attention from `xformers` or not.
+ attention_op (`Callable`, *optional*):
+ The attention operation to use. Defaults to `None` which uses the default attention operation from
+ `xformers`.
+ """
+ if use_memory_efficient_attention_xformers:
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
+ " only available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ if is_xformers_available():
+ dtype = None
+ if attention_op is not None:
+ op_fw, op_bw = attention_op
+ dtype, *_ = op_fw.SUPPORTED_DTYPES
+ q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+ _ = xops.memory_efficient_attention(q, q, q)
+ except Exception as e:
+ raise e
+
+ self.set_attention_backend("xformers")
+
+ @torch.no_grad()
+ def fuse_projections(self):
+ """
+ Fuse the query, key, and value projections into a single projection for efficiency.
+ """
+ # Skip if already fused
+ if getattr(self, "fused_projections", False):
+ return
+
+ device = self.to_q.weight.data.device
+ dtype = self.to_q.weight.data.dtype
+
+ if hasattr(self, "is_cross_attention") and self.is_cross_attention:
+ # Fuse cross-attention key-value projections
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_kv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_kv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ self.to_kv.bias.copy_(concatenated_bias)
+ else:
+ # Fuse self-attention projections
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_qkv = nn.Linear(in_features, out_features, bias=self.use_bias, device=device, dtype=dtype)
+ self.to_qkv.weight.copy_(concatenated_weights)
+ if hasattr(self, "use_bias") and self.use_bias:
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ self.to_qkv.bias.copy_(concatenated_bias)
+
+ # Handle added projections for models like SD3, Flux, etc.
+ if (
+ getattr(self, "add_q_proj", None) is not None
+ and getattr(self, "add_k_proj", None) is not None
+ and getattr(self, "add_v_proj", None) is not None
+ ):
+ concatenated_weights = torch.cat(
+ [self.add_q_proj.weight.data, self.add_k_proj.weight.data, self.add_v_proj.weight.data]
+ )
+ in_features = concatenated_weights.shape[1]
+ out_features = concatenated_weights.shape[0]
+
+ self.to_added_qkv = nn.Linear(
+ in_features, out_features, bias=self.added_proj_bias, device=device, dtype=dtype
+ )
+ self.to_added_qkv.weight.copy_(concatenated_weights)
+ if self.added_proj_bias:
+ concatenated_bias = torch.cat(
+ [self.add_q_proj.bias.data, self.add_k_proj.bias.data, self.add_v_proj.bias.data]
+ )
+ self.to_added_qkv.bias.copy_(concatenated_bias)
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ """
+ Unfuse the query, key, and value projections back to separate projections.
+ """
+ # Skip if not fused
+ if not getattr(self, "fused_projections", False):
+ return
+
+ # Remove fused projection layers
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+
+ if hasattr(self, "to_added_qkv"):
+ delattr(self, "to_added_qkv")
+
+ self.fused_projections = False
+
+ def set_attention_slice(self, slice_size: int) -> None:
+ """
+ Set the slice size for attention computation.
+
+ Args:
+ slice_size (`int`):
+ The slice size for attention computation.
+ """
+ if hasattr(self, "sliceable_head_dim") and slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ processor = None
+
+ # Try to get a compatible processor for sliced attention
+ if slice_size is not None:
+ processor = self._get_compatible_processor("sliced")
+
+ # If no processor was found or slice_size is None, use default processor
+ if processor is None:
+ processor = self.default_processor_cls()
+
+ self.set_processor(processor)
+
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
+ """
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ batch_size, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
+ """
+ Reshape the tensor for multi-head attention processing.
+
+ Args:
+ tensor (`torch.Tensor`): The tensor to reshape.
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor.
+
+ Returns:
+ `torch.Tensor`: The reshaped tensor.
+ """
+ head_size = self.heads
+ if tensor.ndim == 3:
+ batch_size, seq_len, dim = tensor.shape
+ extra_dim = 1
+ else:
+ batch_size, extra_dim, seq_len, dim = tensor.shape
+ tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3)
+
+ if out_dim == 3:
+ tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)
+
+ return tensor
+
+ def get_attention_scores(
+ self, query: torch.Tensor, key: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Compute the attention scores.
+
+ Args:
+ query (`torch.Tensor`): The query tensor.
+ key (`torch.Tensor`): The key tensor.
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use.
+
+ Returns:
+ `torch.Tensor`: The attention probabilities/scores.
+ """
+ dtype = query.dtype
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if attention_mask is None:
+ baddbmm_input = torch.empty(
+ query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
+ )
+ beta = 0
+ else:
+ baddbmm_input = attention_mask
+ beta = 1
+
+ attention_scores = torch.baddbmm(
+ baddbmm_input,
+ query,
+ key.transpose(-1, -2),
+ beta=beta,
+ alpha=self.scale,
+ )
+ del baddbmm_input
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ del attention_scores
+
+ attention_probs = attention_probs.to(dtype)
+
+ return attention_probs
+
+ def prepare_attention_mask(
+ self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
+ ) -> torch.Tensor:
+ """
+ Prepare the attention mask for the attention computation.
+
+ Args:
+ attention_mask (`torch.Tensor`): The attention mask to prepare.
+ target_length (`int`): The target length of the attention mask.
+ batch_size (`int`): The batch size for repeating the attention mask.
+ out_dim (`int`, *optional*, defaults to `3`): Output dimension.
+
+ Returns:
+ `torch.Tensor`: The prepared attention mask.
+ """
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
+ # Instead, we can manually construct the padding tensor.
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
+ # remaining_length: int = target_length - current_length
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+ def norm_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
+ """
+ Normalize the encoder hidden states.
+
+ Args:
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
+
+ Returns:
+ `torch.Tensor`: The normalized encoder hidden states.
+ """
+ assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
+ if isinstance(self.norm_cross, nn.LayerNorm):
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ elif isinstance(self.norm_cross, nn.GroupNorm):
+ # Group norm norms along the channels dimension and expects
+ # input to be in the shape of (N, C, *). In this case, we want
+ # to norm along the hidden dimension, so we need to move
+ # (batch_size, sequence_length, hidden_size) ->
+ # (batch_size, hidden_size, sequence_length)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
+ else:
+ assert False
+
+ return encoder_hidden_states
+
+
def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
# "feed_forward_chunk_size" can be used to save memory
if hidden_states.shape[chunk_dim] % chunk_size != 0:
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
new file mode 100644
index 0000000000..c00ec7dd6e
--- /dev/null
+++ b/src/diffusers/models/attention_dispatch.py
@@ -0,0 +1,1218 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import functools
+import inspect
+import math
+from enum import Enum
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+
+import torch
+
+from ..utils import (
+ get_logger,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
+ is_sageattention_available,
+ is_sageattention_version,
+ is_torch_npu_available,
+ is_torch_version,
+ is_torch_xla_available,
+ is_torch_xla_version,
+ is_xformers_available,
+ is_xformers_version,
+)
+from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
+
+
+_REQUIRED_FLASH_VERSION = "2.6.3"
+_REQUIRED_SAGE_VERSION = "2.1.1"
+_REQUIRED_FLEX_VERSION = "2.5.0"
+_REQUIRED_XLA_VERSION = "2.2"
+_REQUIRED_XFORMERS_VERSION = "0.0.29"
+
+_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
+_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
+_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
+_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
+_CAN_USE_NPU_ATTN = is_torch_npu_available()
+_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
+_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+
+
+if _CAN_USE_FLASH_ATTN:
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
+else:
+ flash_attn_func = None
+ flash_attn_varlen_func = None
+
+
+if _CAN_USE_FLASH_ATTN_3:
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+else:
+ flash_attn_3_func = None
+ flash_attn_3_varlen_func = None
+
+
+if _CAN_USE_SAGE_ATTN:
+ from sageattention import (
+ sageattn,
+ sageattn_qk_int8_pv_fp8_cuda,
+ sageattn_qk_int8_pv_fp8_cuda_sm90,
+ sageattn_qk_int8_pv_fp16_cuda,
+ sageattn_qk_int8_pv_fp16_triton,
+ sageattn_varlen,
+ )
+else:
+ sageattn = None
+ sageattn_qk_int8_pv_fp16_cuda = None
+ sageattn_qk_int8_pv_fp16_triton = None
+ sageattn_qk_int8_pv_fp8_cuda = None
+ sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+ sageattn_varlen = None
+
+
+if _CAN_USE_FLEX_ATTN:
+ # We cannot import the flex_attention function from the package directly because it is expected (from the
+ # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+ # compiled function.
+ import torch.nn.attention.flex_attention as flex_attention
+
+
+if _CAN_USE_NPU_ATTN:
+ from torch_npu import npu_fusion_attention
+else:
+ npu_fusion_attention = None
+
+
+if _CAN_USE_XLA_ATTN:
+ from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+else:
+ xla_flash_attention = None
+
+
+if _CAN_USE_XFORMERS_ATTN:
+ import xformers.ops as xops
+else:
+ xops = None
+
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+
+# TODO(aryan): Add support for the following:
+# - Sage Attention++
+# - block sparse, radial and other attention methods
+# - CP with sage attention, flex, xformers, other missing backends
+# - Add support for normal and CP training with backends that don't support it yet
+
+_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
+_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
+_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
+
+
+class AttentionBackendName(str, Enum):
+ # EAGER = "eager"
+
+ # `flash-attn`
+ FLASH = "flash"
+ FLASH_VARLEN = "flash_varlen"
+ _FLASH_3 = "_flash_3"
+ _FLASH_VARLEN_3 = "_flash_varlen_3"
+
+ # PyTorch native
+ FLEX = "flex"
+ NATIVE = "native"
+ _NATIVE_CUDNN = "_native_cudnn"
+ _NATIVE_EFFICIENT = "_native_efficient"
+ _NATIVE_FLASH = "_native_flash"
+ _NATIVE_MATH = "_native_math"
+ _NATIVE_NPU = "_native_npu"
+ _NATIVE_XLA = "_native_xla"
+
+ # `sageattention`
+ SAGE = "sage"
+ SAGE_VARLEN = "sage_varlen"
+ _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
+ _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
+ _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
+ _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"
+ # TODO: let's not add support for Sparge Attention now because it requires tuning per model
+ # We can look into supporting something "autotune"-ing in the future
+ # SPARGE = "sparge"
+
+ # `xformers`
+ XFORMERS = "xformers"
+
+
+class _AttentionBackendRegistry:
+ _backends = {}
+ _constraints = {}
+ _supported_arg_names = {}
+ _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
+ _checks_enabled = DIFFUSERS_ATTN_CHECKS
+
+ @classmethod
+ def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None):
+ logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}")
+
+ def decorator(func):
+ cls._backends[backend] = func
+ cls._constraints[backend] = constraints or []
+ cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
+ return func
+
+ return decorator
+
+ @classmethod
+ def get_active_backend(cls):
+ return cls._active_backend, cls._backends[cls._active_backend]
+
+ @classmethod
+ def list_backends(cls):
+ return list(cls._backends.keys())
+
+
+@contextlib.contextmanager
+def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
+ """
+ Context manager to set the active attention backend.
+ """
+ if backend not in _AttentionBackendRegistry._backends:
+ raise ValueError(f"Backend {backend} is not registered.")
+
+ backend = AttentionBackendName(backend)
+ _check_attention_backend_requirements(backend)
+
+ old_backend = _AttentionBackendRegistry._active_backend
+ _AttentionBackendRegistry._active_backend = backend
+
+ try:
+ yield
+ finally:
+ _AttentionBackendRegistry._active_backend = old_backend
+
+
+def dispatch_attention_fn(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ *,
+ backend: Optional[AttentionBackendName] = None,
+) -> torch.Tensor:
+ attention_kwargs = attention_kwargs or {}
+
+ if backend is None:
+ # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment
+ # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager
+ backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend()
+ else:
+ backend_name = AttentionBackendName(backend)
+ backend_fn = _AttentionBackendRegistry._backends.get(backend_name)
+
+ kwargs = {
+ "query": query,
+ "key": key,
+ "value": value,
+ "attn_mask": attn_mask,
+ "dropout_p": dropout_p,
+ "is_causal": is_causal,
+ "scale": scale,
+ **attention_kwargs,
+ }
+ if is_torch_version(">=", "2.5.0"):
+ kwargs["enable_gqa"] = enable_gqa
+
+ if _AttentionBackendRegistry._checks_enabled:
+ removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name])
+ if removed_kwargs:
+ logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.")
+ for check in _AttentionBackendRegistry._constraints.get(backend_name):
+ check(**kwargs)
+
+ kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
+ return backend_fn(**kwargs)
+
+
+# ===== Checks =====
+# A list of very simple functions to catch common errors quickly when debugging.
+
+
+def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None:
+ if attn_mask is not None and is_causal:
+ raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.")
+
+
+def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.device != key.device or query.device != value.device:
+ raise ValueError("Query, key, and value must be on the same device.")
+ if query.dtype != key.dtype or query.dtype != value.dtype:
+ raise ValueError("Query, key, and value must have the same dtype.")
+
+
+def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device(query, key, value)
+ if query.device.type != "cuda":
+ raise ValueError("Query, key, and value must be on a CUDA device.")
+
+
+def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable:
+ def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_device_cuda(query, key, value)
+ if torch.cuda.get_device_capability(query.device) < (major, minor):
+ raise ValueError(
+ f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}."
+ )
+
+ return check_device_cuda
+
+
+def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ if query.dtype != key.dtype:
+ raise ValueError("Query and key must have the same dtype.")
+ if query.dtype != value.dtype:
+ raise ValueError("Query and value must have the same dtype.")
+
+
+def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None:
+ _check_qkv_dtype_match(query, key, value)
+ if query.dtype not in (torch.bfloat16, torch.float16):
+ raise ValueError("Query, key, and value must be either bfloat16 or float16.")
+
+
+def _check_shape(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+) -> None:
+ if query.shape[-1] != key.shape[-1]:
+ raise ValueError("Query and key must have the same last dimension.")
+ if query.shape[-2] != value.shape[-2]:
+ raise ValueError("Query and value must have the same second to last dimension.")
+ if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]:
+ raise ValueError("Attention mask must match the key's second to last dimension.")
+
+
+# ===== Helper functions =====
+
+
+def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
+ if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
+ if not _CAN_USE_FLASH_ATTN:
+ raise RuntimeError(
+ f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
+ )
+
+ elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
+ if not _CAN_USE_FLASH_ATTN_3:
+ raise RuntimeError(
+ f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
+ )
+
+ elif backend in [
+ AttentionBackendName.SAGE,
+ AttentionBackendName.SAGE_VARLEN,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ ]:
+ if not _CAN_USE_SAGE_ATTN:
+ raise RuntimeError(
+ f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
+ )
+
+ elif backend == AttentionBackendName.FLEX:
+ if not _CAN_USE_FLEX_ATTN:
+ raise RuntimeError(
+ f"Flex Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch>=2.5.0`."
+ )
+
+ elif backend == AttentionBackendName._NATIVE_NPU:
+ if not _CAN_USE_NPU_ATTN:
+ raise RuntimeError(
+ f"NPU Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_npu`."
+ )
+
+ elif backend == AttentionBackendName._NATIVE_XLA:
+ if not _CAN_USE_XLA_ATTN:
+ raise RuntimeError(
+ f"XLA Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `torch_xla>={_REQUIRED_XLA_VERSION}`."
+ )
+
+ elif backend == AttentionBackendName.XFORMERS:
+ if not _CAN_USE_XFORMERS_ATTN:
+ raise RuntimeError(
+ f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
+ )
+
+
+@functools.lru_cache(maxsize=128)
+def _prepare_for_flash_attn_or_sage_varlen_without_mask(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen_with_mask(
+ batch_size: int,
+ seq_len_q: int,
+ attn_mask: torch.Tensor,
+ device: Optional[torch.device] = None,
+):
+ seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
+ seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
+ cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
+ cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
+ cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+ max_seqlen_q = seqlens_q.max().item()
+ max_seqlen_k = seqlens_k.max().item()
+ return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+
+
+def _prepare_for_flash_attn_or_sage_varlen(
+ batch_size: int,
+ seq_len_q: int,
+ seq_len_kv: int,
+ attn_mask: Optional[torch.Tensor] = None,
+ device: Optional[torch.device] = None,
+) -> None:
+ if attn_mask is None:
+ return _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, device)
+ return _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, device)
+
+
+def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor:
+ """
+ Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in
+ FlashAttention/Sage varlen.
+
+ Supports 1D to 4D shapes and common broadcasting patterns.
+ """
+ if attn_mask.dtype != torch.bool:
+ raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.")
+
+ if attn_mask.ndim == 1:
+ # [seq_len_k] -> broadcast across batch
+ attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 2:
+ # [batch_size, seq_len_k]. Maybe broadcast across batch
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 3:
+ # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension
+ # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen.
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask."
+ )
+ attn_mask = attn_mask.any(dim=1)
+ attn_mask = attn_mask.expand(batch_size, seq_len_k)
+
+ elif attn_mask.ndim == 4:
+ # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions
+ if attn_mask.size(0) not in [1, batch_size]:
+ raise ValueError(
+ f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask."
+ )
+ attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K]
+ attn_mask = attn_mask.any(dim=(1, 2)) # [B, K]
+
+ else:
+ raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}")
+
+ if attn_mask.shape != (batch_size, seq_len_k):
+ raise ValueError(
+ f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})"
+ )
+
+ return attn_mask
+
+
+def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return q_idx >= kv_idx
+
+
+# ===== torch op registrations =====
+# Registrations are required for fullgraph tracing compatibility
+
+
+# TODO: library.custom_op and register_fake probably need version guards?
+# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
+# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
+@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
+def _wrapped_flash_attn_3_original(
+ query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ out, lse = flash_attn_3_func(query, key, value)
+ lse = lse.permute(0, 2, 1)
+ return out, lse
+
+
+@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
+def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size, seq_len, num_heads, head_dim = query.shape
+ lse_shape = (batch_size, seq_len, num_heads)
+ return torch.empty_like(query), query.new_empty(lse_shape)
+
+
+# ===== Attention backends =====
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ out = flash_attn_func(
+ q=query,
+ k=key,
+ v=value,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLASH_VARLEN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ alibi_slopes: Optional[torch.Tensor] = None,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = flash_attn_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ dropout_p=dropout_p,
+ softmax_scale=scale,
+ causal=is_causal,
+ window_size=window_size,
+ softcap=softcap,
+ alibi_slopes=alibi_slopes,
+ deterministic=deterministic,
+ return_attn_probs=return_attn_probs,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+) -> torch.Tensor:
+ out, lse, *_ = flash_attn_3_func(
+ q=query,
+ k=key,
+ v=value,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ attention_chunk=0,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._FLASH_VARLEN_3,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_varlen_attention_3(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ scale: Optional[float] = None,
+ is_causal: bool = False,
+ window_size: Tuple[int, int] = (-1, -1),
+ softcap: float = 0.0,
+ deterministic: bool = False,
+ return_attn_probs: bool = False,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out, lse, *_ = flash_attn_3_varlen_func(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ seqused_q=None,
+ seqused_k=None,
+ softmax_scale=scale,
+ causal=is_causal,
+ qv=None,
+ q_descale=None,
+ k_descale=None,
+ v_descale=None,
+ window_size=window_size,
+ softcap=softcap,
+ num_splits=1,
+ pack_gqa=None,
+ deterministic=deterministic,
+ sm_margin=0,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return (out, lse) if return_attn_probs else out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.FLEX,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _native_flex_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+ return_lse: bool = False,
+ kernel_options: Optional[Dict[str, Any]] = None,
+) -> torch.Tensor:
+ # TODO: should we LRU cache the block mask creation?
+ score_mod = None
+ block_mask = None
+ batch_size, seq_len_q, num_heads, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask):
+ block_mask = attn_mask
+ elif is_causal:
+ block_mask = flex_attention.create_block_mask(
+ _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device
+ )
+ elif torch.is_tensor(attn_mask):
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+
+ attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)
+
+ if attn_mask.dtype == torch.bool:
+ # TODO: this probably does not work but verify!
+ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ return attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+
+ block_mask = flex_attention.create_block_mask(
+ mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
+ )
+ else:
+
+ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
+ return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]
+ else:
+ raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")
+
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = flex_attention.flex_attention(
+ query=query,
+ key=key,
+ value=value,
+ score_mod=score_mod,
+ block_mask=block_mask,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ return_lse=return_lse,
+ kernel_options=kernel_options,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.NATIVE,
+ constraints=[_check_device, _check_shape],
+)
+def _native_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_CUDNN,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_cudnn_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_EFFICIENT,
+ constraints=[_check_device, _check_shape],
+)
+def _native_efficient_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_FLASH,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_flash_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=None, # not supported
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_MATH,
+ constraints=[_check_device, _check_shape],
+)
+def _native_math_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
+ out = torch.nn.functional.scaled_dot_product_attention(
+ query=query,
+ key=key,
+ value=value,
+ attn_mask=attn_mask,
+ dropout_p=dropout_p,
+ is_causal=is_causal,
+ scale=scale,
+ enable_gqa=enable_gqa,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_NPU,
+ constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _native_npu_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ dropout_p: float = 0.0,
+ scale: Optional[float] = None,
+) -> torch.Tensor:
+ return npu_fusion_attention(
+ query,
+ key,
+ value,
+ query.size(2), # num_heads
+ input_layout="BSND",
+ pse=None,
+ scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
+ pre_tockens=65536,
+ next_tokens=65536,
+ keep_prob=1.0 - dropout_p,
+ sync=False,
+ inner_precise=0,
+ )[0]
+
+
+# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._NATIVE_XLA,
+ constraints=[_check_device, _check_shape],
+)
+def _native_xla_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+) -> torch.Tensor:
+ query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+ query = query / math.sqrt(query.shape[-1])
+ out = xla_flash_attention(
+ q=query,
+ k=key,
+ v=value,
+ causal=is_causal,
+ )
+ out = out.permute(0, 2, 1, 3)
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ sm_scale=scale,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.SAGE_VARLEN,
+ constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _sage_varlen_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ cu_seqlens_q: Optional[torch.Tensor] = None,
+ cu_seqlens_k: Optional[torch.Tensor] = None,
+ max_seqlen_q: Optional[int] = None,
+ max_seqlen_k: Optional[int] = None,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ smooth_k: bool = True,
+ attn_mask: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ batch_size, seq_len_q, _, _ = query.shape
+ _, seq_len_kv, _, _ = key.shape
+
+ if attn_mask is not None:
+ attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+ if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
+ (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+ _prepare_for_flash_attn_or_sage_varlen(
+ batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ )
+ )
+ else:
+ seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+ cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+ cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+ key_valid, value_valid = [], []
+ for b in range(batch_size):
+ valid_len = seqlens_k[b]
+ key_valid.append(key[b, :valid_len])
+ value_valid.append(value[b, :valid_len])
+
+ query_packed = query.flatten(0, 1)
+ key_packed = torch.cat(key_valid, dim=0)
+ value_packed = torch.cat(value_valid, dim=0)
+
+ out = sageattn_varlen(
+ q=query_packed,
+ k=key_packed,
+ v=value_packed,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_q,
+ max_seqlen_k=max_seqlen_k,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ )
+ out = out.unflatten(0, (batch_size, -1))
+
+ return out
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
+ constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp8_cuda_sm90(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_cuda_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
+ pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32",
+ smooth_k: bool = True,
+ smooth_v: bool = False,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_cuda(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ is_causal=is_causal,
+ qk_quant_gran=qk_quant_gran,
+ sm_scale=scale,
+ pv_accum_dtype=pv_accum_dtype,
+ smooth_k=smooth_k,
+ smooth_v=smooth_v,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
+ constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
+)
+def _sage_qk_int8_pv_fp16_triton_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
+ smooth_k: bool = True,
+ return_lse: bool = False,
+) -> torch.Tensor:
+ return sageattn_qk_int8_pv_fp16_triton(
+ q=query,
+ k=key,
+ v=value,
+ tensor_layout="NHD",
+ quantization_backend=quantization_backend,
+ is_causal=is_causal,
+ sm_scale=scale,
+ smooth_k=smooth_k,
+ return_lse=return_lse,
+ )
+
+
+@_AttentionBackendRegistry.register(
+ AttentionBackendName.XFORMERS,
+ constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
+)
+def _xformers_attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attn_mask: Optional[torch.Tensor] = None,
+ dropout_p: float = 0.0,
+ is_causal: bool = False,
+ scale: Optional[float] = None,
+ enable_gqa: bool = False,
+) -> torch.Tensor:
+ batch_size, seq_len_q, num_heads_q, _ = query.shape
+ _, seq_len_kv, num_heads_kv, _ = key.shape
+
+ if is_causal:
+ attn_mask = xops.LowerTriangularMask()
+ elif attn_mask is not None:
+ if attn_mask.ndim == 2:
+ attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
+ elif attn_mask.ndim != 4:
+ raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
+ attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
+
+ if enable_gqa:
+ if num_heads_q % num_heads_kv != 0:
+ raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
+ num_heads_per_group = num_heads_q // num_heads_kv
+ query = query.unflatten(2, (num_heads_kv, -1))
+ key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+ value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
+
+ out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
+
+ if enable_gqa:
+ out = out.flatten(2, 3)
+
+ return out
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 4760cfd40b..990245de17 100755
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -2272,558 +2272,6 @@ class FusedAuraFlowAttnProcessor2_0:
return hidden_states
-class FluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FusedFluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FusedFluxAttnProcessor2_0_NPU:
- """Attention processor used typically in processing the SD3-like self-attention projections."""
-
- def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- qkv = attn.to_qkv(hidden_states)
- split_size = qkv.shape[-1] // 3
- query, key, value = torch.split(qkv, split_size, dim=-1)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- # `context` projections.
- if encoder_hidden_states is not None:
- encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
- split_size = encoder_qkv.shape[-1] // 3
- (
- encoder_hidden_states_query_proj,
- encoder_hidden_states_key_proj,
- encoder_hidden_states_value_proj,
- ) = torch.split(encoder_qkv, split_size, dim=-1)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- if query.dtype in (torch.float16, torch.bfloat16):
- hidden_states = torch_npu.npu_fusion_attention(
- query,
- key,
- value,
- attn.heads,
- input_layout="BNSD",
- pse=None,
- scale=1.0 / math.sqrt(query.shape[-1]),
- pre_tockens=65536,
- next_tockens=65536,
- keep_prob=1.0,
- sync=False,
- inner_precise=0,
- )[0]
- else:
- hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
-class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
- """Flux Attention processor for IP-Adapter."""
-
- def __init__(
- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
- ):
- super().__init__()
-
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
-
- self.hidden_size = hidden_size
- self.cross_attention_dim = cross_attention_dim
-
- if not isinstance(num_tokens, (tuple, list)):
- num_tokens = [num_tokens]
-
- if not isinstance(scale, list):
- scale = [scale] * len(num_tokens)
- if len(scale) != len(num_tokens):
- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
- self.scale = scale
-
- self.to_k_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
- self.to_v_ip = nn.ModuleList(
- [
- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
- for _ in range(len(num_tokens))
- ]
- )
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ip_hidden_states: Optional[List[torch.Tensor]] = None,
- ip_adapter_masks: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- hidden_states_query_proj = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
- )
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- # IP-adapter
- ip_query = hidden_states_query_proj
- ip_attn_output = torch.zeros_like(hidden_states)
-
- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
- ):
- ip_key = to_k_ip(current_ip_hidden_states)
- ip_value = to_v_ip(current_ip_hidden_states)
-
- ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
- # TODO: add support for attn.scale when we move to Torch 2.1
- current_ip_hidden_states = F.scaled_dot_product_attention(
- ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
- )
- current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
- batch_size, -1, attn.heads * head_dim
- )
- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
- ip_attn_output += scale * current_ip_hidden_states
-
- return hidden_states, encoder_hidden_states, ip_attn_output
- else:
- return hidden_states
-
-
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -3453,106 +2901,6 @@ class XLAFlashAttnProcessor2_0:
return hidden_states
-class XLAFluxFlashAttnProcessor2_0:
- r"""
- Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
- """
-
- def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
- )
- if is_torch_xla_version("<", "2.3"):
- raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
- if is_spmd() and is_torch_xla_version("<", "2.4"):
- raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
- self.partition_spec = partition_spec
-
- def __call__(
- self,
- attn: Attention,
- hidden_states: torch.FloatTensor,
- encoder_hidden_states: torch.FloatTensor = None,
- attention_mask: Optional[torch.FloatTensor] = None,
- image_rotary_emb: Optional[torch.Tensor] = None,
- ) -> torch.FloatTensor:
- batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-
- # `sample` projections.
- query = attn.to_q(hidden_states)
- key = attn.to_k(hidden_states)
- value = attn.to_v(hidden_states)
-
- inner_dim = key.shape[-1]
- head_dim = inner_dim // attn.heads
-
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
-
- # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
- if encoder_hidden_states is not None:
- # `context` projections.
- encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
- encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
- encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
-
- encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
- encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
- batch_size, -1, attn.heads, head_dim
- ).transpose(1, 2)
-
- if attn.norm_added_q is not None:
- encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
- if attn.norm_added_k is not None:
- encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
-
- # attention
- query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
- key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
- value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
-
- if image_rotary_emb is not None:
- from .embeddings import apply_rotary_emb
-
- query = apply_rotary_emb(query, image_rotary_emb)
- key = apply_rotary_emb(key, image_rotary_emb)
-
- query /= math.sqrt(head_dim)
- hidden_states = flash_attention(query, key, value, causal=False)
-
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
- hidden_states = hidden_states.to(query.dtype)
-
- if encoder_hidden_states is not None:
- encoder_hidden_states, hidden_states = (
- hidden_states[:, : encoder_hidden_states.shape[1]],
- hidden_states[:, encoder_hidden_states.shape[1] :],
- )
-
- # linear proj
- hidden_states = attn.to_out[0](hidden_states)
- # dropout
- hidden_states = attn.to_out[1](hidden_states)
-
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
-
- return hidden_states, encoder_hidden_states
- else:
- return hidden_states
-
-
class MochiVaeAttnProcessor2_0:
r"""
Attention processor used in Mochi VAE.
@@ -5992,17 +5340,6 @@ class LoRAAttnAddedKVProcessor:
pass
-class FluxSingleAttnProcessor2_0(FluxAttnProcessor2_0):
- r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
- """
-
- def __init__(self):
- deprecation_message = "`FluxSingleAttnProcessor2_0` is deprecated and will be removed in a future version. Please use `FluxAttnProcessor2_0` instead."
- deprecate("FluxSingleAttnProcessor2_0", "0.32.0", deprecation_message)
- super().__init__()
-
-
class SanaLinearAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product linear attention.
@@ -6167,6 +5504,111 @@ class PAGIdentitySanaLinearAttnProcessor2_0:
return hidden_states
+class FluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxSingleAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxSingleAttnProcessor` is deprecated and will be removed in a future version. Please use `FluxAttnProcessorSDPA` instead."
+ deprecate("FluxSingleAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FusedFluxAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FusedFluxAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxAttnProcessor`"
+ deprecate("FusedFluxAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ return FluxAttnProcessor(*args, **kwargs)
+
+
+class FluxIPAdapterJointAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`FluxIPAdapterJointAttnProcessor2_0` is deprecated and this will be removed in a future version. Please use `FluxIPAdapterAttnProcessor`"
+ deprecate("FluxIPAdapterJointAttnProcessor2_0", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxIPAdapterAttnProcessor
+
+ return FluxIPAdapterAttnProcessor(*args, **kwargs)
+
+
+class FluxAttnProcessor2_0_NPU:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "FluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_native_npu"
+ return processor
+
+
+class FusedFluxAttnProcessor2_0_NPU:
+ def __new__(self):
+ deprecation_message = (
+ "FusedFluxAttnProcessor2_0_NPU is deprecated and will be removed in a future version. An "
+ "alternative solution to use NPU Flash Attention will be provided in the future."
+ )
+ deprecate("FusedFluxAttnProcessor2_0_NPU", "1.0.0", deprecation_message, standard_warn=False)
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ processor = FluxAttnProcessor()
+ processor._attention_backend = "_fused_npu"
+ return processor
+
+
+class XLAFluxFlashAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "XLAFluxFlashAttnProcessor2_0 is deprecated and will be removed in diffusers 1.0.0. An "
+ "alternative solution to using XLA Flash Attention will be provided in the future."
+ )
+ deprecate("XLAFluxFlashAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+
+ if is_torch_xla_version("<", "2.3"):
+ raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+ if is_spmd() and is_torch_xla_version("<", "2.4"):
+ raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+
+ from .transformers.transformer_flux import FluxAttnProcessor
+
+ if len(args) > 0 or kwargs.get("partition_spec", None) is not None:
+ deprecation_message = (
+ "partition_spec was not used in the processor implementation when it was added. Passing it "
+ "is a no-op and support for it will be removed."
+ )
+ deprecate("partition_spec", "1.0.0", deprecation_message)
+
+ processor = FluxAttnProcessor(*args, **kwargs)
+ processor._attention_backend = "_native_xla"
+ return processor
+
+
ADDED_KV_ATTENTION_PROCESSORS = (
AttnAddedKVProcessor,
SlicedAttnAddedKVProcessor,
diff --git a/src/diffusers/models/auto_model.py b/src/diffusers/models/auto_model.py
index 96785ce6f5..bfe386f1f6 100644
--- a/src/diffusers/models/auto_model.py
+++ b/src/diffusers/models/auto_model.py
@@ -117,8 +117,8 @@ class AutoModel(ConfigMixin):
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
index 15fe8e02e0..7ab79a0bb8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py
@@ -110,8 +110,11 @@ class CosmosPatchEmbed3d(nn.Module):
self.patch_size = patch_size
self.patch_method = patch_method
- self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
- self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=False)
+ wavelets = _WAVELETS.get(patch_method).clone()
+ arange = torch.arange(wavelets.shape[0])
+
+ self.register_buffer("wavelets", wavelets, persistent=False)
+ self.register_buffer("_arange", arange, persistent=False)
def _dwt(self, hidden_states: torch.Tensor, mode: str = "reflect", rescale=False) -> torch.Tensor:
dtype = hidden_states.dtype
@@ -185,12 +188,11 @@ class CosmosUnpatcher3d(nn.Module):
self.patch_size = patch_size
self.patch_method = patch_method
- self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=False)
- self.register_buffer(
- "_arange",
- torch.arange(_WAVELETS[patch_method].shape[0]),
- persistent=False,
- )
+ wavelets = _WAVELETS.get(patch_method).clone()
+ arange = torch.arange(wavelets.shape[0])
+
+ self.register_buffer("wavelets", wavelets, persistent=False)
+ self.register_buffer("_arange", arange, persistent=False)
def _idwt(self, hidden_states: torch.Tensor, rescale: bool = False) -> torch.Tensor:
device = hidden_states.device
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index 49cefcd8a1..608de25da5 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
CACHE_T = 2
+class AvgDown3D(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert in_channels * self.factor % out_channels == 0
+ self.group_size = in_channels * self.factor // out_channels
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+ pad = (0, 0, 0, 0, pad_t, 0)
+ x = F.pad(x, pad)
+ B, C, T, H, W = x.shape
+ x = x.view(
+ B,
+ C,
+ T // self.factor_t,
+ self.factor_t,
+ H // self.factor_s,
+ self.factor_s,
+ W // self.factor_s,
+ self.factor_s,
+ )
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+ x = x.view(
+ B,
+ C * self.factor,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.view(
+ B,
+ self.out_channels,
+ self.group_size,
+ T // self.factor_t,
+ H // self.factor_s,
+ W // self.factor_s,
+ )
+ x = x.mean(dim=2)
+ return x
+
+
+class DupUp3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ factor_t,
+ factor_s=1,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ self.factor_t = factor_t
+ self.factor_s = factor_s
+ self.factor = self.factor_t * self.factor_s * self.factor_s
+
+ assert out_channels * self.factor % in_channels == 0
+ self.repeats = out_channels * self.factor // in_channels
+
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+ x = x.repeat_interleave(self.repeats, dim=1)
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ self.factor_t,
+ self.factor_s,
+ self.factor_s,
+ x.size(2),
+ x.size(3),
+ x.size(4),
+ )
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+ x = x.view(
+ x.size(0),
+ self.out_channels,
+ x.size(2) * self.factor_t,
+ x.size(4) * self.factor_s,
+ x.size(6) * self.factor_s,
+ )
+ if first_chunk:
+ x = x[:, :, self.factor_t - 1 :, :, :]
+ return x
+
+
class WanCausalConv3d(nn.Conv3d):
r"""
A custom 3D causal convolution layer with feature caching support.
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
"""
- def __init__(self, dim: int, mode: str) -> None:
+ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
super().__init__()
self.dim = dim
self.mode = mode
+ # default to dim //2
+ if upsample_out_dim is None:
+ upsample_out_dim = dim // 2
+
# layers
if mode == "upsample2d":
self.resample = nn.Sequential(
- WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
elif mode == "upsample3d":
self.resample = nn.Sequential(
- WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
)
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
return x
+class WanResidualDownBlock(nn.Module):
+ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+ super().__init__()
+
+ # Shortcut path with downsample
+ self.avg_shortcut = AvgDown3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_downsample else 1,
+ factor_s=2 if down_flag else 1,
+ )
+
+ # Main path with residual blocks and downsample
+ resnets = []
+ for _ in range(num_res_blocks):
+ resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+ in_dim = out_dim
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add the final downsample block
+ if down_flag:
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
+ self.downsampler = WanResample(out_dim, mode=mode)
+ else:
+ self.downsampler = None
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ x_copy = x.clone()
+ for resnet in self.resnets:
+ x = resnet(x, feat_cache, feat_idx)
+ if self.downsampler is not None:
+ x = self.downsampler(x, feat_cache, feat_idx)
+
+ return x + self.avg_shortcut(x_copy)
+
+
class WanEncoder3d(nn.Module):
r"""
A 3D encoder module.
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
def __init__(
self,
+ in_channels: int = 3,
dim=128,
z_dim=4,
dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@ class WanEncoder3d(nn.Module):
temperal_downsample=[True, True, False],
dropout=0.0,
non_linearity: str = "silu",
+ is_residual: bool = False, # wan 2.2 vae use a residual downblock
):
super().__init__()
self.dim = dim
@@ -403,23 +544,35 @@ class WanEncoder3d(nn.Module):
scale = 1.0
# init block
- self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+ self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
# downsample blocks
self.down_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
- for _ in range(num_res_blocks):
- self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
- if scale in attn_scales:
- self.down_blocks.append(WanAttentionBlock(out_dim))
- in_dim = out_dim
+ if is_residual:
+ self.down_blocks.append(
+ WanResidualDownBlock(
+ in_dim,
+ out_dim,
+ dropout,
+ num_res_blocks,
+ temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+ down_flag=i != len(dim_mult) - 1,
+ )
+ )
+ else:
+ for _ in range(num_res_blocks):
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ self.down_blocks.append(WanAttentionBlock(out_dim))
+ in_dim = out_dim
- # downsample block
- if i != len(dim_mult) - 1:
- mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
- self.down_blocks.append(WanResample(out_dim, mode=mode))
- scale /= 2.0
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
+ scale /= 2.0
# middle blocks
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
return x
+class WanResidualUpBlock(nn.Module):
+ """
+ A block that handles upsampling for the WanVAE decoder.
+
+ Args:
+ in_dim (int): Input dimension
+ out_dim (int): Output dimension
+ num_res_blocks (int): Number of residual blocks
+ dropout (float): Dropout rate
+ temperal_upsample (bool): Whether to upsample on temporal dimension
+ up_flag (bool): Whether to upsample or not
+ non_linearity (str): Type of non-linearity to use
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ num_res_blocks: int,
+ dropout: float = 0.0,
+ temperal_upsample: bool = False,
+ up_flag: bool = False,
+ non_linearity: str = "silu",
+ ):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ if up_flag:
+ self.avg_shortcut = DupUp3D(
+ in_dim,
+ out_dim,
+ factor_t=2 if temperal_upsample else 1,
+ factor_s=2,
+ )
+ else:
+ self.avg_shortcut = None
+
+ # create residual blocks
+ resnets = []
+ current_dim = in_dim
+ for _ in range(num_res_blocks + 1):
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+ current_dim = out_dim
+
+ self.resnets = nn.ModuleList(resnets)
+
+ # Add upsampling layer if needed
+ if up_flag:
+ upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+ self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+ else:
+ self.upsampler = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+ """
+ Forward pass through the upsampling block.
+
+ Args:
+ x (torch.Tensor): Input tensor
+ feat_cache (list, optional): Feature cache for causal convolutions
+ feat_idx (list, optional): Feature index for cache management
+
+ Returns:
+ torch.Tensor: Output tensor
+ """
+ x_copy = x.clone()
+
+ for resnet in self.resnets:
+ if feat_cache is not None:
+ x = resnet(x, feat_cache, feat_idx)
+ else:
+ x = resnet(x)
+
+ if self.upsampler is not None:
+ if feat_cache is not None:
+ x = self.upsampler(x, feat_cache, feat_idx)
+ else:
+ x = self.upsampler(x)
+
+ if self.avg_shortcut is not None:
+ x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+ return x
+
+
class WanUpBlock(nn.Module):
"""
A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
self.gradient_checkpointing = False
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
"""
Forward pass through the upsampling block.
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
temperal_upsample=[False, True, True],
dropout=0.0,
non_linearity: str = "silu",
+ out_channels: int = 3,
+ is_residual: bool = False,
):
super().__init__()
self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
# dimensions
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
- scale = 1.0 / 2 ** (len(dim_mult) - 2)
# init block
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
self.up_blocks = nn.ModuleList([])
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
# residual (+attention) blocks
- if i > 0:
+ if i > 0 and not is_residual:
+ # wan vae 2.1
in_dim = in_dim // 2
- # Determine if we need upsampling
+ # determine if we need upsampling
+ up_flag = i != len(dim_mult) - 1
+ # determine upsampling mode, if not upsampling, set to None
upsample_mode = None
- if i != len(dim_mult) - 1:
- upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
+ if up_flag and temperal_upsample[i]:
+ upsample_mode = "upsample3d"
+ elif up_flag:
+ upsample_mode = "upsample2d"
# Create and add the upsampling block
- up_block = WanUpBlock(
- in_dim=in_dim,
- out_dim=out_dim,
- num_res_blocks=num_res_blocks,
- dropout=dropout,
- upsample_mode=upsample_mode,
- non_linearity=non_linearity,
- )
+ if is_residual:
+ up_block = WanResidualUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ temperal_upsample=temperal_upsample[i] if up_flag else False,
+ up_flag=up_flag,
+ non_linearity=non_linearity,
+ )
+ else:
+ up_block = WanUpBlock(
+ in_dim=in_dim,
+ out_dim=out_dim,
+ num_res_blocks=num_res_blocks,
+ dropout=dropout,
+ upsample_mode=upsample_mode,
+ non_linearity=non_linearity,
+ )
self.up_blocks.append(up_block)
- # Update scale for next iteration
- if upsample_mode is not None:
- scale *= 2.0
-
# output blocks
self.norm_out = WanRMS_norm(out_dim, images=False)
- self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+ self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
self.gradient_checkpointing = False
- def forward(self, x, feat_cache=None, feat_idx=[0]):
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
## conv1
if feat_cache is not None:
idx = feat_idx[0]
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
## upsamples
for up_block in self.up_blocks:
- x = up_block(x, feat_cache, feat_idx)
+ x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
## head
x = self.norm_out(x)
@@ -656,6 +909,77 @@ class WanDecoder3d(nn.Module):
return x
+def patchify(x, patch_size):
+ if patch_size == 1:
+ return x
+
+ if x.dim() == 4:
+ # x shape: [batch_size, channels, height, width]
+ batch_size, channels, height, width = x.shape
+
+ # Ensure height and width are divisible by patch_size
+ if height % patch_size != 0 or width % patch_size != 0:
+ raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+ # Reshape to [batch_size, channels, height//patch_size, patch_size, width//patch_size, patch_size]
+ x = x.view(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
+
+ # Rearrange to [batch_size, channels * patch_size * patch_size, height//patch_size, width//patch_size]
+ x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
+ x = x.view(batch_size, channels * patch_size * patch_size, height // patch_size, width // patch_size)
+
+ elif x.dim() == 5:
+ # x shape: [batch_size, channels, frames, height, width]
+ batch_size, channels, frames, height, width = x.shape
+
+ # Ensure height and width are divisible by patch_size
+ if height % patch_size != 0 or width % patch_size != 0:
+ raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+ # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+ x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+ # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+ x = x.permute(0, 1, 4, 6, 2, 3, 5).contiguous()
+ x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+ else:
+ raise ValueError(f"Invalid input shape: {x.shape}")
+
+ return x
+
+
+def unpatchify(x, patch_size):
+ if patch_size == 1:
+ return x
+
+ if x.dim() == 4:
+ # x shape: [b, (c * patch_size * patch_size), h, w]
+ batch_size, c_patches, height, width = x.shape
+ channels = c_patches // (patch_size * patch_size)
+
+ # Reshape to [b, c, patch_size, patch_size, h, w]
+ x = x.view(batch_size, channels, patch_size, patch_size, height, width)
+
+ # Rearrange to [b, c, h * patch_size, w * patch_size]
+ x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
+ x = x.view(batch_size, channels, height * patch_size, width * patch_size)
+
+ elif x.dim() == 5:
+ # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+ batch_size, c_patches, frames, height, width = x.shape
+ channels = c_patches // (patch_size * patch_size)
+
+ # Reshape to [b, c, patch_size, patch_size, f, h, w]
+ x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+ # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+ x = x.permute(0, 1, 4, 5, 2, 6, 3).contiguous()
+ x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+ return x
+
+
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
@@ -671,6 +995,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
def __init__(
self,
base_dim: int = 96,
+ decoder_base_dim: Optional[int] = None,
z_dim: int = 16,
dim_mult: Tuple[int] = [1, 2, 4, 4],
num_res_blocks: int = 2,
@@ -713,6 +1038,13 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
2.8251,
1.9160,
],
+ is_residual: bool = False,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ patch_size: Optional[int] = None,
+ scale_factor_temporal: Optional[int] = 4,
+ scale_factor_spatial: Optional[int] = 8,
+ clip_output: bool = True,
) -> None:
super().__init__()
@@ -720,14 +1052,33 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
self.temperal_downsample = temperal_downsample
self.temperal_upsample = temperal_downsample[::-1]
+ if decoder_base_dim is None:
+ decoder_base_dim = base_dim
+
self.encoder = WanEncoder3d(
- base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+ in_channels=in_channels,
+ dim=base_dim,
+ z_dim=z_dim * 2,
+ dim_mult=dim_mult,
+ num_res_blocks=num_res_blocks,
+ attn_scales=attn_scales,
+ temperal_downsample=temperal_downsample,
+ dropout=dropout,
+ is_residual=is_residual,
)
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
self.decoder = WanDecoder3d(
- base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+ dim=decoder_base_dim,
+ z_dim=z_dim,
+ dim_mult=dim_mult,
+ num_res_blocks=num_res_blocks,
+ attn_scales=attn_scales,
+ temperal_upsample=self.temperal_upsample,
+ dropout=dropout,
+ out_channels=out_channels,
+ is_residual=is_residual,
)
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
@@ -827,6 +1178,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
return self.tiled_encode(x)
self.clear_cache()
+ if self.config.patch_size is not None:
+ x = patchify(x, patch_size=self.config.patch_size)
iter_ = 1 + (num_frame - 1) // 4
for i in range(iter_):
self._enc_conv_idx = [0]
@@ -884,12 +1237,17 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for i in range(num_frame):
self._conv_idx = [0]
if i == 0:
- out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+ out = self.decoder(
+ x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+ )
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
- out = torch.clamp(out, min=-1.0, max=1.0)
+ if self.config.clip_output:
+ out = torch.clamp(out, min=-1.0, max=1.0)
+ if self.config.patch_size is not None:
+ out = unpatchify(out, patch_size=self.config.patch_size)
self.clear_cache()
if not return_dict:
return (out,)
diff --git a/src/diffusers/models/cache_utils.py b/src/diffusers/models/cache_utils.py
index 3fd1ca6e9d..605c0d588c 100644
--- a/src/diffusers/models/cache_utils.py
+++ b/src/diffusers/models/cache_utils.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from contextlib import contextmanager
+
from ..utils.logging import get_logger
@@ -25,6 +27,7 @@ class CacheMixin:
Supported caching techniques:
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
- [FasterCache](https://huggingface.co/papers/2410.19355)
+ - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
"""
_cache_config = None
@@ -62,8 +65,10 @@ class CacheMixin:
from ..hooks import (
FasterCacheConfig,
+ FirstBlockCacheConfig,
PyramidAttentionBroadcastConfig,
apply_faster_cache,
+ apply_first_block_cache,
apply_pyramid_attention_broadcast,
)
@@ -72,31 +77,36 @@ class CacheMixin:
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
)
- if isinstance(config, PyramidAttentionBroadcastConfig):
- apply_pyramid_attention_broadcast(self, config)
- elif isinstance(config, FasterCacheConfig):
+ if isinstance(config, FasterCacheConfig):
apply_faster_cache(self, config)
+ elif isinstance(config, FirstBlockCacheConfig):
+ apply_first_block_cache(self, config)
+ elif isinstance(config, PyramidAttentionBroadcastConfig):
+ apply_pyramid_attention_broadcast(self, config)
else:
raise ValueError(f"Cache config {type(config)} is not supported.")
self._cache_config = config
def disable_cache(self) -> None:
- from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+ from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+ from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
if self._cache_config is None:
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
return
- if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
- registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
- elif isinstance(self._cache_config, FasterCacheConfig):
- registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ if isinstance(self._cache_config, FasterCacheConfig):
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, FirstBlockCacheConfig):
+ registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+ registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+ elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
else:
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
@@ -106,3 +116,15 @@ class CacheMixin:
from ..hooks import HookRegistry
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+ @contextmanager
+ def cache_context(self, name: str):
+ r"""Context manager that provides additional methods for cache management."""
+ from ..hooks import HookRegistry
+
+ registry = HookRegistry.check_if_exists_or_initialize(self)
+ registry._set_context(name)
+
+ yield
+
+ registry._set_context(None)
diff --git a/src/diffusers/models/controlnets/controlnet_flux.py b/src/diffusers/models/controlnets/controlnet_flux.py
index d8e99ee45e..063ff5bd8e 100644
--- a/src/diffusers/models/controlnets/controlnet_flux.py
+++ b/src/diffusers/models/controlnets/controlnet_flux.py
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
block_samples = block_samples + (hidden_states,)
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
- single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+ single_block_samples = single_block_samples + (hidden_states,)
# controlnet block
controlnet_block_samples = ()
diff --git a/src/diffusers/models/controlnets/controlnet_union.py b/src/diffusers/models/controlnets/controlnet_union.py
index d731ef137b..3df3bbe312 100644
--- a/src/diffusers/models/controlnets/controlnet_union.py
+++ b/src/diffusers/models/controlnets/controlnet_union.py
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[control_idx]
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)
else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
alpha = self.spatial_ch_projs(x[:, idx])
alpha = alpha.unsqueeze(-1).unsqueeze(-1)
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
controlnet_cond_fuser += condition + alpha
else:
controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device) # 0.1 to 1.0
- if from_multi:
+ if from_multi or len(control_type_idx) == 1:
scales = scales * conditioning_scale[0]
down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
mid_block_res_sample = mid_block_res_sample * scales[-1] # last one
- elif from_multi:
+ elif from_multi or len(control_type_idx) == 1:
down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 4f268bfa01..b51f5d7aec 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
return emb
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
"""
This function generates 1D positional embeddings from a grid.
@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
emb_cos = torch.cos(out) # (M, D/2)
emb = torch.concat([emb_sin, emb_cos], dim=1) # (M, D)
+
+ # flip sine and cosine embeddings
+ if flip_sin_to_cos:
+ emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
return emb
@@ -1176,6 +1181,7 @@ def apply_rotary_emb(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
use_real: bool = True,
use_real_unbind_dim: int = -1,
+ sequence_dim: int = 2,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1193,8 +1199,15 @@ def apply_rotary_emb(
"""
if use_real:
cos, sin = freqs_cis # [S, D]
- cos = cos[None, None]
- sin = sin[None, None]
+ if sequence_dim == 2:
+ cos = cos[None, None, :, :]
+ sin = sin[None, None, :, :]
+ elif sequence_dim == 1:
+ cos = cos[None, :, None, :]
+ sin = sin[None, :, None, :]
+ else:
+ raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
cos, sin = cos.to(x.device), sin.to(x.device)
if use_real_unbind_dim == -1:
@@ -1238,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
return x
-class FluxPosEmbed(nn.Module):
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
- def __init__(self, theta: int, axes_dim: List[int]):
- super().__init__()
- self.theta = theta
- self.axes_dim = axes_dim
-
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
- n_axes = ids.shape[-1]
- cos_out = []
- sin_out = []
- pos = ids.float()
- is_mps = ids.device.type == "mps"
- is_npu = ids.device.type == "npu"
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
- for i in range(n_axes):
- cos, sin = get_1d_rotary_pos_embed(
- self.axes_dim[i],
- pos[:, i],
- theta=self.theta,
- repeat_interleave_real=True,
- use_real=True,
- freqs_dtype=freqs_dtype,
- )
- cos_out.append(cos)
- sin_out.append(sin)
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
- return freqs_cos, freqs_sin
-
-
class TimestepEmbedding(nn.Module):
def __init__(
self,
@@ -2619,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
projected_image_embeds.append(image_embed)
return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+ deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+ from .transformers.transformer_flux import FluxPosEmbed
+
+ return FluxPosEmbed(*args, **kwargs)
diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 419e757ce4..23f9a5c8b6 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -16,9 +16,10 @@
import importlib
import inspect
+import math
import os
from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -39,6 +40,7 @@ from ..utils import (
_get_model_file,
deprecate,
is_accelerate_available,
+ is_accelerate_version,
is_gguf_available,
is_torch_available,
is_torch_version,
@@ -254,6 +256,10 @@ def load_model_dict_into_meta(
param = param.to(dtype)
set_module_kwargs["dtype"] = dtype
+ if is_accelerate_version(">", "1.8.1"):
+ set_module_kwargs["non_blocking"] = True
+ set_module_kwargs["clear_cache"] = False
+
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -523,3 +529,60 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+ mismatched_keys = []
+ if not ignore_mismatched_sizes:
+ return mismatched_keys
+ for checkpoint_key in loaded_keys:
+ model_key = checkpoint_key
+ # If the checkpoint is sharded, we may not have the key here.
+ if checkpoint_key not in state_dict:
+ continue
+
+ if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+ mismatched_keys.append(
+ (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+ )
+ del state_dict[checkpoint_key]
+ return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+ """
+ Expand a device map to return the correspondence parameter name to device.
+ """
+ new_device_map = {}
+ for module, device in device_map.items():
+ new_device_map.update(
+ {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+ )
+ return new_device_map
+
+
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype) -> None:
+ """
+ This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+ device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+ which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+ very large margin.
+ """
+ # Remove disk and cpu devices, and cast to proper torch.device
+ accelerator_device_map = {
+ param: torch.device(device)
+ for param, device in expanded_device_map.items()
+ if str(device) not in ["cpu", "disk"]
+ }
+ parameter_count = defaultdict(lambda: 0)
+ for param_name, device in accelerator_device_map.items():
+ try:
+ param = model.get_parameter(param_name)
+ except AttributeError:
+ param = model.get_buffer(param_name)
+ parameter_count[device] += math.prod(param.shape)
+
+ # This will kick off the caching allocator to avoid having to Malloc afterwards
+ for device, param_count in parameter_count.items():
+ _ = torch.empty(param_count, dtype=dtype, device=device, requires_grad=False)
diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py
index 52f004f6f9..010b737745 100644
--- a/src/diffusers/models/modeling_flax_utils.py
+++ b/src/diffusers/models/modeling_flax_utils.py
@@ -369,8 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `token` or log in with `huggingface-cli "
- "login`."
+ "token having permission to this repo with `token` or log in with `hf auth login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 5fa04fb260..815f12a707 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -62,10 +62,14 @@ from ..utils.hub_utils import (
load_or_create_model_card,
populate_model_card,
)
+from ..utils.torch_utils import empty_device_cache
from .model_loading_utils import (
+ _caching_allocator_warmup,
_determine_device_map,
+ _expand_device_map,
_fetch_index_file,
_fetch_index_file_legacy,
+ _find_mismatched_keys,
_load_state_dict_into_model,
load_model_dict_into_meta,
load_state_dict,
@@ -168,7 +172,11 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
for name, param in parameter.named_parameters():
last_dtype = param.dtype
- if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+ if (
+ hasattr(parameter, "_keep_in_fp32_modules")
+ and parameter._keep_in_fp32_modules
+ and any(m in name for m in parameter._keep_in_fp32_modules)
+ ):
continue
if param.is_floating_point():
@@ -266,6 +274,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_keep_in_fp32_modules = None
_skip_layerwise_casting_patterns = None
_supports_group_offloading = True
+ _repeated_blocks = []
def __init__(self):
super().__init__()
@@ -601,6 +610,60 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_to_disk_path=offload_to_disk_path,
)
+ def set_attention_backend(self, backend: str) -> None:
+ """
+ Set the attention backend for the model.
+
+ Args:
+ backend (`str`):
+ The name of the backend to set. Must be one of the available backends defined in
+ `AttentionBackendName`. Available backends can be found in
+ `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+ attention as backend.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+
+ # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ backend = backend.lower()
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+ backend = AttentionBackendName(backend)
+ _check_attention_backend_requirements(backend)
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = backend
+
+ def reset_attention_backend(self) -> None:
+ """
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+ the torch native scaled dot product attention.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = None
+
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
@@ -880,8 +943,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
@@ -1404,6 +1467,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
return super().float(*args)
+ def compile_repeated_blocks(self, *args, **kwargs):
+ """
+ Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
+ compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
+ https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
+ substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
+
+ The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
+ model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
+ module whose class name matches will be compiled.
+
+ Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
+ positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
+ `torch.compile`.
+ """
+ repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+ if not repeated_blocks:
+ raise ValueError(
+ "`_repeated_blocks` attribute is empty. "
+ f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
+ )
+ has_compiled_region = False
+ for submod in self.modules():
+ if submod.__class__.__name__ in repeated_blocks:
+ submod.compile(*args, **kwargs)
+ has_compiled_region = True
+
+ if not has_compiled_region:
+ raise ValueError(
+ f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+ )
+
@classmethod
def _load_pretrained_model(
cls,
@@ -1435,11 +1531,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
- mismatched_keys = []
-
- assign_to_params_buffers = None
- error_msgs = []
-
# Deal with offload
if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
@@ -1448,18 +1539,27 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
- if offload_folder is not None:
+ else:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True
+ # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
+ # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
+ # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+ # tensors using their expected shape and not performing any initialization of the memory (empty data).
+ # When the actual device allocations happen, the allocator already has a pool of unused device memory
+ # that it can re-use for faster loading of the model.
+ # TODO: add support for warmup with hf_quantizer
+ if device_map is not None and hf_quantizer is None:
+ expanded_device_map = _expand_device_map(device_map, expected_keys)
+ _caching_allocator_warmup(model, expanded_device_map, dtype)
+
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+ state_dict_folder, state_dict_index = None, None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
- else:
- state_dict_folder = None
- state_dict_index = None
if state_dict is not None:
# load_state_dict will manage the case where we pass a dict instead of a file
@@ -1469,38 +1569,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
if len(resolved_model_file) > 1:
resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+ mismatched_keys = []
+ assign_to_params_buffers = None
+ error_msgs = []
+
for shard_file in resolved_model_file:
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
- # If the checkpoint is sharded, we may not have the key here.
- if checkpoint_key not in state_dict:
- continue
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
-
mismatched_keys += _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
+ state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes
)
if low_cpu_mem_usage:
@@ -1520,9 +1596,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
if assign_to_params_buffers is None:
assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
-
error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ empty_device_cache()
+
if offload_index is not None and len(offload_index) > 0:
save_offload_index(offload_index, offload_folder)
offload_index = None
@@ -1858,4 +1935,9 @@ class LegacyModelMixin(ModelMixin):
# resolve remapping
remapped_class = _fetch_remapped_cls_from_config(config, cls)
- return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+ if remapped_class is cls:
+ return super(LegacyModelMixin, remapped_class).from_pretrained(
+ pretrained_model_name_or_path, **kwargs_copy
+ )
+ else:
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index cc03a0ccbc..dd8813369b 100755
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -31,6 +31,7 @@ if is_torch_available():
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
+ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
from .transformer_wan import WanTransformer3DModel
from .transformer_wan_vace import WanVACETransformer3DModel
diff --git a/src/diffusers/models/transformers/transformer_chroma.py b/src/diffusers/models/transformers/transformer_chroma.py
index d11f6c2a5e..5823ae9d3d 100644
--- a/src/diffusers/models/transformers/transformer_chroma.py
+++ b/src/diffusers/models/transformers/transformer_chroma.py
@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
from ..cache_utils import CacheMixin
from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
- processor = FluxAttnProcessor2_0()
+ processor = FluxAttnProcessor()
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
- qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
eps=eps,
)
@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):
class ChromaTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux, modified for Chroma.
@@ -407,6 +405,7 @@ class ChromaTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+ _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
@register_to_config
@@ -474,106 +473,6 @@ class ChromaTransformer2DModel(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/src/diffusers/models/transformers/transformer_cogview4.py b/src/diffusers/models/transformers/transformer_cogview4.py
index e4144d0c8e..dc45befb98 100644
--- a/src/diffusers/models/transformers/transformer_cogview4.py
+++ b/src/diffusers/models/transformers/transformer_cogview4.py
@@ -21,6 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
return hidden_states, encoder_hidden_states
+@maybe_allow_in_graph
class CogView4TransformerBlock(nn.Module):
def __init__(
self,
diff --git a/src/diffusers/models/transformers/transformer_cosmos.py b/src/diffusers/models/transformers/transformer_cosmos.py
index 6c312b7a5a..373b470ae3 100644
--- a/src/diffusers/models/transformers/transformer_cosmos.py
+++ b/src/diffusers/models/transformers/transformer_cosmos.py
@@ -20,6 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
from ..attention_processor import Attention
@@ -186,9 +187,15 @@ class CosmosAttnProcessor2_0:
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
# 4. Prepare for GQA
- query_idx = torch.tensor(query.size(3), device=query.device)
- key_idx = torch.tensor(key.size(3), device=key.device)
- value_idx = torch.tensor(value.size(3), device=value.device)
+ if torch.onnx.is_in_onnx_export():
+ query_idx = torch.tensor(query.size(3), device=query.device)
+ key_idx = torch.tensor(key.size(3), device=key.device)
+ value_idx = torch.tensor(value.size(3), device=value.device)
+
+ else:
+ query_idx = query.size(3)
+ key_idx = key.size(3)
+ value_idx = value.size(3)
key = key.repeat_interleave(query_idx // key_idx, dim=3)
value = value.repeat_interleave(query_idx // value_idx, dim=3)
@@ -377,7 +384,7 @@ class CosmosLearnablePositionalEmbed(nn.Module):
return (emb / norm).type_as(hidden_states)
-class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
+class CosmosTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index ab579a0eb5..9080cd508d 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -12,28 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from typing import Any, Dict, Optional, Tuple, Union
+import inspect
+from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
+import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.import_utils import is_torch_npu_available
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
+from ..embeddings import (
+ CombinedTimestepGuidanceTextProjEmbeddings,
+ CombinedTimestepTextProjEmbeddings,
+ apply_rotary_emb,
+ get_1d_rotary_pos_embed,
+)
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
@@ -42,6 +42,307 @@ from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNo
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ encoder_query = encoder_key = encoder_value = None
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+
+ encoder_query = encoder_key = encoder_value = (None,)
+ if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"):
+ encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1)
+
+ return query, key, value, encoder_query, encoder_key, encoder_value
+
+
+def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
+ if attn.fused_projections:
+ return _get_fused_projections(attn, hidden_states, encoder_hidden_states)
+ return _get_projections(attn, hidden_states, encoder_hidden_states)
+
+
+class FluxAttnProcessor:
+ _attention_backend = None
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if attn.added_kv_proj_dim is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query, key, value, attn_mask=attention_mask, backend=self._attention_backend
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ return hidden_states
+
+
+class FluxIPAdapterAttnProcessor(torch.nn.Module):
+ """Flux Attention processor for IP-Adapter."""
+
+ _attention_backend = None
+
+ def __init__(
+ self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
+ ):
+ super().__init__()
+
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ )
+
+ self.hidden_size = hidden_size
+ self.cross_attention_dim = cross_attention_dim
+
+ if not isinstance(num_tokens, (tuple, list)):
+ num_tokens = [num_tokens]
+
+ if not isinstance(scale, list):
+ scale = [scale] * len(num_tokens)
+ if len(scale) != len(num_tokens):
+ raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
+ self.scale = scale
+
+ self.to_k_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+ self.to_v_ip = nn.ModuleList(
+ [
+ nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
+ for _ in range(len(num_tokens))
+ ]
+ )
+
+ def __call__(
+ self,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ip_hidden_states: Optional[List[torch.Tensor]] = None,
+ ip_adapter_masks: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size = hidden_states.shape[0]
+
+ query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
+ attn, hidden_states, encoder_hidden_states
+ )
+
+ query = query.unflatten(-1, (attn.heads, -1))
+ key = key.unflatten(-1, (attn.heads, -1))
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+ ip_query = query
+
+ if encoder_hidden_states is not None:
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+
+ encoder_query = attn.norm_added_q(encoder_query)
+ encoder_key = attn.norm_added_k(encoder_key)
+
+ query = torch.cat([encoder_query, query], dim=1)
+ key = torch.cat([encoder_key, key], dim=1)
+ value = torch.cat([encoder_value, value], dim=1)
+
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ if encoder_hidden_states is not None:
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ # IP-adapter
+ ip_attn_output = torch.zeros_like(hidden_states)
+
+ for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
+ ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
+ ):
+ ip_key = to_k_ip(current_ip_hidden_states)
+ ip_value = to_v_ip(current_ip_hidden_states)
+
+ ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim)
+ ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim)
+
+ current_ip_hidden_states = dispatch_attention_fn(
+ ip_query,
+ ip_key,
+ ip_value,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ )
+ current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, -1, attn.heads * attn.head_dim)
+ current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+ ip_attn_output += scale * current_ip_hidden_states
+
+ return hidden_states, encoder_hidden_states, ip_attn_output
+ else:
+ return hidden_states
+
+
+class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = FluxAttnProcessor
+ _available_processors = [
+ FluxAttnProcessor,
+ FluxIPAdapterAttnProcessor,
+ ]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ eps: float = 1e-5,
+ out_dim: int = None,
+ context_pre_only: Optional[bool] = None,
+ pre_only: bool = False,
+ elementwise_affine: bool = True,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.head_dim = dim_head
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.added_proj_bias = added_proj_bias
+
+ self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+
+ if not self.pre_only:
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if added_kv_proj_dim is not None:
+ self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+ self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
+ self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
@@ -54,6 +355,8 @@ class FluxSingleTransformerBlock(nn.Module):
self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
deprecation_message = (
"Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
"should be set explicitly using the `set_attn_processor` method."
@@ -61,17 +364,15 @@ class FluxSingleTransformerBlock(nn.Module):
deprecate("npu_processor", "0.34.0", deprecation_message)
processor = FluxAttnProcessor2_0_NPU()
else:
- processor = FluxAttnProcessor2_0()
+ processor = FluxAttnProcessor()
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
bias=True,
processor=processor,
- qk_norm="rms_norm",
eps=1e-6,
pre_only=True,
)
@@ -79,10 +380,14 @@ class FluxSingleTransformerBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -100,7 +405,8 @@ class FluxSingleTransformerBlock(nn.Module):
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
- return hidden_states
+ encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
+ return encoder_hidden_states, hidden_states
@maybe_allow_in_graph
@@ -113,17 +419,15 @@ class FluxTransformerBlock(nn.Module):
self.norm1 = AdaLayerNormZero(dim)
self.norm1_context = AdaLayerNormZero(dim)
- self.attn = Attention(
+ self.attn = FluxAttention(
query_dim=dim,
- cross_attention_dim=None,
added_kv_proj_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
out_dim=dim,
context_pre_only=False,
bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
eps=eps,
)
@@ -147,6 +451,7 @@ class FluxTransformerBlock(nn.Module):
encoder_hidden_states, emb=temb
)
joint_attention_kwargs = joint_attention_kwargs or {}
+
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
@@ -175,7 +480,6 @@ class FluxTransformerBlock(nn.Module):
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
-
context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
encoder_hidden_states = encoder_hidden_states + context_attn_output
@@ -190,8 +494,45 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
+class FluxPosEmbed(nn.Module):
+ # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ cos_out = []
+ sin_out = []
+ pos = ids.float()
+ is_mps = ids.device.type == "mps"
+ is_npu = ids.device.type == "npu"
+ freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+ for i in range(n_axes):
+ cos, sin = get_1d_rotary_pos_embed(
+ self.axes_dim[i],
+ pos[:, i],
+ theta=self.theta,
+ repeat_interleave_real=True,
+ use_real=True,
+ freqs_dtype=freqs_dtype,
+ )
+ cos_out.append(cos)
+ sin_out.append(sin)
+ freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+ freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+ return freqs_cos, freqs_sin
+
+
class FluxTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
):
"""
The Transformer model introduced in Flux.
@@ -227,6 +568,7 @@ class FluxTransformer2DModel(
_supports_gradient_checkpointing = True
_no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+ _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
@register_to_config
def __init__(
@@ -286,106 +628,6 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
-
-
- This API is 🧪 experimental.
-
-
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
-
-
- This API is 🧪 experimental.
-
-
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -484,6 +726,7 @@ class FluxTransformer2DModel(
encoder_hidden_states,
temb,
image_rotary_emb,
+ joint_attention_kwargs,
)
else:
@@ -506,20 +749,22 @@ class FluxTransformer2DModel(
)
else:
hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
for index_block, block in enumerate(self.single_transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
- hidden_states = self._gradient_checkpointing_func(
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
+ encoder_hidden_states,
temb,
image_rotary_emb,
+ joint_attention_kwargs,
)
else:
- hidden_states = block(
+ encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
@@ -529,12 +774,7 @@ class FluxTransformer2DModel(
if controlnet_single_block_samples is not None:
interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
interval_control = int(np.ceil(interval_control))
- hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
- hidden_states[:, encoder_hidden_states.shape[1] :, ...]
- + controlnet_single_block_samples[index_block // interval_control]
- )
-
- hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+ hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
diff --git a/src/diffusers/models/transformers/transformer_hunyuan_video.py b/src/diffusers/models/transformers/transformer_hunyuan_video.py
index c48c586a28..6944a6c536 100644
--- a/src/diffusers/models/transformers/transformer_hunyuan_video.py
+++ b/src/diffusers/models/transformers/transformer_hunyuan_video.py
@@ -870,6 +870,12 @@ class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin,
"HunyuanVideoPatchEmbed",
"HunyuanVideoTokenRefiner",
]
+ _repeated_blocks = [
+ "HunyuanVideoTransformerBlock",
+ "HunyuanVideoSingleTransformerBlock",
+ "HunyuanVideoPatchEmbed",
+ "HunyuanVideoTokenRefiner",
+ ]
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_ltx.py b/src/diffusers/models/transformers/transformer_ltx.py
index 38b7b6af50..79149fb760 100644
--- a/src/diffusers/models/transformers/transformer_ltx.py
+++ b/src/diffusers/models/transformers/transformer_ltx.py
@@ -1,4 +1,4 @@
-# Copyright 2025 The Genmo team and The HuggingFace Team.
+# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,19 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import math
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
-import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
@@ -37,20 +37,30 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LTXVideoAttentionProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = "`LTXVideoAttentionProcessor2_0` is deprecated and this will be removed in a future version. Please use `LTXVideoAttnProcessor`"
+ deprecate("LTXVideoAttentionProcessor2_0", "1.0.0", deprecation_message)
+
+ return LTXVideoAttnProcessor(*args, **kwargs)
+
+
+class LTXVideoAttnProcessor:
r"""
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
- used in the LTX model. It applies a normalization layer and rotary embedding on the query and key vector.
+ Processor for implementing attention (SDPA is used by default if you're using PyTorch 2.0). This is used in the LTX
+ model. It applies a normalization layer and rotary embedding on the query and key vector.
"""
+ _attention_backend = None
+
def __init__(self):
- if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError(
- "LTXVideoAttentionProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+ if is_torch_version("<", "2.0"):
+ raise ValueError(
+ "LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
- attn: Attention,
+ attn: "LTXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -78,14 +88,20 @@ class LTXVideoAttentionProcessor2_0:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.to_out[0](hidden_states)
@@ -93,6 +109,70 @@ class LTXVideoAttentionProcessor2_0:
return hidden_states
+class LTXAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = LTXVideoAttnProcessor
+ _available_processors = [LTXVideoAttnProcessor]
+
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ kv_heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = True,
+ cross_attention_dim: Optional[int] = None,
+ out_bias: bool = True,
+ qk_norm: str = "rms_norm_across_heads",
+ processor=None,
+ ):
+ super().__init__()
+ if qk_norm != "rms_norm_across_heads":
+ raise NotImplementedError("Only 'rms_norm_across_heads' is supported as a valid value for `qk_norm`.")
+
+ self.head_dim = dim_head
+ self.inner_dim = dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.use_bias = bias
+ self.dropout = dropout
+ self.out_dim = query_dim
+ self.heads = heads
+
+ norm_eps = 1e-5
+ norm_elementwise_affine = True
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm_k = torch.nn.RMSNorm(dim_head * kv_heads, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
+ self.to_k = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_v = torch.nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
+ self.to_out = torch.nn.ModuleList([])
+ self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
+ self.to_out.append(torch.nn.Dropout(dropout))
+
+ if processor is None:
+ processor = self._default_processor_cls()
+ self.set_processor(processor)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
+ unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
+ if len(unused_kwargs) > 0:
+ logger.warning(
+ f"attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
+ )
+ kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
+
+
class LTXVideoRotaryPosEmbed(nn.Module):
def __init__(
self,
@@ -231,7 +311,7 @@ class LTXVideoTransformerBlock(nn.Module):
super().__init__()
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
- self.attn1 = Attention(
+ self.attn1 = LTXAttention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
@@ -240,11 +320,10 @@ class LTXVideoTransformerBlock(nn.Module):
cross_attention_dim=None,
out_bias=attention_out_bias,
qk_norm=qk_norm,
- processor=LTXVideoAttentionProcessor2_0(),
)
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
- self.attn2 = Attention(
+ self.attn2 = LTXAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
@@ -253,7 +332,6 @@ class LTXVideoTransformerBlock(nn.Module):
bias=attention_bias,
out_bias=attention_out_bias,
qk_norm=qk_norm,
- processor=LTXVideoAttentionProcessor2_0(),
)
self.ff = FeedForward(dim, activation_fn=activation_fn)
@@ -299,7 +377,9 @@ class LTXVideoTransformerBlock(nn.Module):
@maybe_allow_in_graph
-class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin):
+class LTXVideoTransformer3DModel(
+ ModelMixin, ConfigMixin, AttentionMixin, FromOriginalModelMixin, PeftAdapterMixin, CacheMixin
+):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).
@@ -328,6 +408,7 @@ class LTXVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = ["norm"]
+ _repeated_blocks = ["LTXVideoTransformerBlock"]
@register_to_config
def __init__(
diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py
new file mode 100644
index 0000000000..236fca690a
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py
@@ -0,0 +1,607 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import (
+ PixArtAlphaTextProjection,
+ TimestepEmbedding,
+ get_1d_rotary_pos_embed,
+ get_1d_sincos_pos_embed_from_grid,
+)
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin, get_parameter_dtype
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class SkyReelsV2AttnProcessor2_0:
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError(
+ "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+ )
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+ x_rotated = torch.view_as_complex(hidden_states.to(torch.float32).unflatten(3, (-1, 2)))
+ x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+ return x_out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, rotary_emb)
+ key = apply_rotary_emb(key, rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+ hidden_states_img = F.scaled_dot_product_attention(
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ )
+ hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ )
+
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+# Copied from diffusers.models.transformers.transformer_wan.WanImageEmbedding with WanImageEmbedding -> SkyReelsV2ImageEmbedding
+class SkyReelsV2ImageEmbedding(torch.nn.Module):
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
+ super().__init__()
+
+ self.norm1 = FP32LayerNorm(in_features)
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+ self.norm2 = FP32LayerNorm(out_features)
+ if pos_embed_seq_len is not None:
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
+ else:
+ self.pos_embed = None
+
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+ if self.pos_embed is not None:
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
+
+ hidden_states = self.norm1(encoder_hidden_states_image)
+ hidden_states = self.ff(hidden_states)
+ hidden_states = self.norm2(hidden_states)
+ return hidden_states
+
+
+class SkyReelsV2Timesteps(nn.Module):
+ def __init__(self, num_channels: int, flip_sin_to_cos: bool, output_type: str = "pt"):
+ super().__init__()
+ self.num_channels = num_channels
+ self.output_type = output_type
+ self.flip_sin_to_cos = flip_sin_to_cos
+
+ def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
+ original_shape = timesteps.shape
+ t_emb = get_1d_sincos_pos_embed_from_grid(
+ self.num_channels,
+ timesteps,
+ output_type=self.output_type,
+ flip_sin_to_cos=self.flip_sin_to_cos,
+ )
+ # Reshape back to maintain batch structure
+ if len(original_shape) > 1:
+ t_emb = t_emb.reshape(*original_shape, self.num_channels)
+ return t_emb
+
+
+class SkyReelsV2TimeTextImageEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ time_freq_dim: int,
+ time_proj_dim: int,
+ text_embed_dim: int,
+ image_embed_dim: Optional[int] = None,
+ pos_embed_seq_len: Optional[int] = None,
+ ):
+ super().__init__()
+
+ self.timesteps_proj = SkyReelsV2Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True)
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+ self.act_fn = nn.SiLU()
+ self.time_proj = nn.Linear(dim, time_proj_dim)
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+ self.image_embedder = None
+ if image_embed_dim is not None:
+ self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+
+ def forward(
+ self,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ ):
+ timestep = self.timesteps_proj(timestep)
+
+ time_embedder_dtype = get_parameter_dtype(self.time_embedder)
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+ timestep = timestep.to(time_embedder_dtype)
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+ timestep_proj = self.time_proj(self.act_fn(temb))
+
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class SkyReelsV2RotaryPosEmbed(nn.Module):
+ def __init__(
+ self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ ):
+ super().__init__()
+
+ self.attention_head_dim = attention_head_dim
+ self.patch_size = patch_size
+ self.max_seq_len = max_seq_len
+
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
+ t_dim = attention_head_dim - h_dim - w_dim
+
+ freqs = []
+ for dim in [t_dim, h_dim, w_dim]:
+ freq = get_1d_rotary_pos_embed(
+ dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float32
+ )
+ freqs.append(freq)
+ self.freqs = torch.cat(freqs, dim=1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.patch_size
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+ freqs = self.freqs.to(hidden_states.device)
+ freqs = freqs.split_with_sizes(
+ [
+ self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+ self.attention_head_dim // 6,
+ self.attention_head_dim // 6,
+ ],
+ dim=1,
+ )
+
+ freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+ freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+ return freqs
+
+
+class SkyReelsV2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ ffn_dim: int,
+ num_heads: int,
+ qk_norm: str = "rms_norm_across_heads",
+ cross_attn_norm: bool = False,
+ eps: float = 1e-6,
+ added_kv_proj_dim: Optional[int] = None,
+ ):
+ super().__init__()
+
+ # 1. Self-attention
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ processor=SkyReelsV2AttnProcessor2_0(),
+ )
+
+ # 2. Cross-attention
+ self.attn2 = Attention(
+ query_dim=dim,
+ heads=num_heads,
+ kv_heads=num_heads,
+ dim_head=dim // num_heads,
+ qk_norm=qk_norm,
+ eps=eps,
+ bias=True,
+ cross_attention_dim=None,
+ out_bias=True,
+ added_kv_proj_dim=added_kv_proj_dim,
+ added_proj_bias=True,
+ processor=SkyReelsV2AttnProcessor2_0(),
+ )
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+ # 3. Feed-forward
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ rotary_emb: torch.Tensor,
+ attention_mask: torch.Tensor,
+ ) -> torch.Tensor:
+ if temb.dim() == 3:
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
+ elif temb.dim() == 4:
+ # For 4D temb in Diffusion Forcing framework, we assume the shape is (b, 6, f * pp_h * pp_w, inner_dim)
+ e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
+ # 1. Self-attention
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+ attn_output = self.attn1(
+ hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
+ )
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+ # 2. Cross-attention
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ hidden_states = hidden_states + attn_output
+
+ # 3. Feed-forward
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+ hidden_states
+ )
+ ff_output = self.ffn(norm_hidden_states)
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+ return hidden_states
+
+
+class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ r"""
+ A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
+
+ Args:
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+ num_attention_heads (`int`, defaults to `16`):
+ Fixed length for text embeddings.
+ attention_head_dim (`int`, defaults to `128`):
+ The number of channels in each head.
+ in_channels (`int`, defaults to `16`):
+ The number of channels in the input.
+ out_channels (`int`, defaults to `16`):
+ The number of channels in the output.
+ text_dim (`int`, defaults to `4096`):
+ Input dimension for text embeddings.
+ freq_dim (`int`, defaults to `256`):
+ Dimension for sinusoidal time embeddings.
+ ffn_dim (`int`, defaults to `8192`):
+ Intermediate dimension in feed-forward network.
+ num_layers (`int`, defaults to `32`):
+ The number of layers of transformer blocks to use.
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+ Window size for local attention (-1 indicates global attention).
+ cross_attn_norm (`bool`, defaults to `True`):
+ Enable cross-attention normalization.
+ qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+ Enable query/key normalization.
+ eps (`float`, defaults to `1e-6`):
+ Epsilon value for normalization layers.
+ inject_sample_info (`bool`, defaults to `False`):
+ Whether to inject sample information into the model.
+ image_dim (`int`, *optional*):
+ The dimension of the image embeddings.
+ added_kv_proj_dim (`int`, *optional*):
+ The dimension of the added key/value projection.
+ rope_max_seq_len (`int`, defaults to `1024`):
+ The maximum sequence length for the rotary embeddings.
+ pos_embed_seq_len (`int`, *optional*):
+ The sequence length for the positional embeddings.
+ """
+
+ _supports_gradient_checkpointing = True
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+ _no_split_modules = ["SkyReelsV2TransformerBlock"]
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+ @register_to_config
+ def __init__(
+ self,
+ patch_size: Tuple[int] = (1, 2, 2),
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 128,
+ in_channels: int = 16,
+ out_channels: int = 16,
+ text_dim: int = 4096,
+ freq_dim: int = 256,
+ ffn_dim: int = 8192,
+ num_layers: int = 32,
+ cross_attn_norm: bool = True,
+ qk_norm: Optional[str] = "rms_norm_across_heads",
+ eps: float = 1e-6,
+ image_dim: Optional[int] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ rope_max_seq_len: int = 1024,
+ pos_embed_seq_len: Optional[int] = None,
+ inject_sample_info: bool = False,
+ num_frame_per_block: int = 1,
+ ) -> None:
+ super().__init__()
+
+ inner_dim = num_attention_heads * attention_head_dim
+ out_channels = out_channels or in_channels
+
+ # 1. Patch & position embedding
+ self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+ # 2. Condition embeddings
+ # image_embedding_dim=1280 for I2V model
+ self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
+ dim=inner_dim,
+ time_freq_dim=freq_dim,
+ time_proj_dim=inner_dim * 6,
+ text_embed_dim=text_dim,
+ image_embed_dim=image_dim,
+ pos_embed_seq_len=pos_embed_seq_len,
+ )
+
+ # 3. Transformer blocks
+ self.blocks = nn.ModuleList(
+ [
+ SkyReelsV2TransformerBlock(
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+ if inject_sample_info:
+ self.fps_embedding = nn.Embedding(2, inner_dim)
+ self.fps_projection = FeedForward(inner_dim, inner_dim * 6, mult=1, activation_fn="linear-silu")
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.LongTensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ enable_diffusion_forcing: bool = False,
+ fps: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ if attention_kwargs is not None:
+ attention_kwargs = attention_kwargs.copy()
+ lora_scale = attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
+ p_t, p_h, p_w = self.config.patch_size
+ post_patch_num_frames = num_frames // p_t
+ post_patch_height = height // p_h
+ post_patch_width = width // p_w
+
+ rotary_emb = self.rope(hidden_states)
+
+ hidden_states = self.patch_embedding(hidden_states)
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+ causal_mask = None
+ if self.config.num_frame_per_block > 1:
+ block_num = post_patch_num_frames // self.config.num_frame_per_block
+ range_tensor = torch.arange(block_num, device=hidden_states.device).repeat_interleave(
+ self.config.num_frame_per_block
+ )
+ causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1) # f, f
+ causal_mask = causal_mask.view(post_patch_num_frames, 1, 1, post_patch_num_frames, 1, 1)
+ causal_mask = causal_mask.repeat(
+ 1, post_patch_height, post_patch_width, 1, post_patch_height, post_patch_width
+ )
+ causal_mask = causal_mask.reshape(
+ post_patch_num_frames * post_patch_height * post_patch_width,
+ post_patch_num_frames * post_patch_height * post_patch_width,
+ )
+ causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+ timestep, encoder_hidden_states, encoder_hidden_states_image
+ )
+
+ timestep_proj = timestep_proj.unflatten(-1, (6, -1))
+
+ if encoder_hidden_states_image is not None:
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+ if self.config.inject_sample_info:
+ fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
+
+ fps_emb = self.fps_embedding(fps)
+ if enable_diffusion_forcing:
+ timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1)).repeat(
+ timestep.shape[1], 1, 1
+ )
+ else:
+ timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, -1))
+
+ if enable_diffusion_forcing:
+ b, f = timestep.shape
+ temb = temb.view(b, f, 1, 1, -1)
+ timestep_proj = timestep_proj.view(b, f, 1, 1, 6, -1) # (b, f, 1, 1, 6, inner_dim)
+ temb = temb.repeat(1, 1, post_patch_height, post_patch_width, 1).flatten(1, 3)
+ timestep_proj = timestep_proj.repeat(1, 1, post_patch_height, post_patch_width, 1, 1).flatten(
+ 1, 3
+ ) # (b, f, pp_h, pp_w, 6, inner_dim) -> (b, f * pp_h * pp_w, 6, inner_dim)
+ timestep_proj = timestep_proj.transpose(1, 2).contiguous() # (b, 6, f * pp_h * pp_w, inner_dim)
+
+ # 4. Transformer blocks
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for block in self.blocks:
+ hidden_states = self._gradient_checkpointing_func(
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ causal_mask,
+ )
+ else:
+ for block in self.blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ timestep_proj,
+ rotary_emb,
+ causal_mask,
+ )
+
+ if temb.dim() == 2:
+ # If temb is 2D, we assume it has time 1-D time embedding values for each batch.
+ # For models:
+ # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ elif temb.dim() == 3:
+ # If temb is 3D, we assume it has 2-D time embedding values for each batch.
+ # Each time embedding tensor includes values for each latent frame; thus Diffusion Forcing.
+ # For models:
+ # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ shift, scale = (self.scale_shift_table.unsqueeze(2) + temb.unsqueeze(1)).chunk(2, dim=1)
+ shift, scale = shift.squeeze(1), scale.squeeze(1)
+
+ # Move the shift and scale tensors to the same device as hidden_states.
+ # When using multi-GPU inference via accelerate these will be on the
+ # first device rather than the last device, which hidden_states ends up
+ # on.
+ shift = shift.to(hidden_states.device)
+ scale = scale.to(hidden_states.device)
+
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+
+ hidden_states = self.proj_out(hidden_states)
+
+ hidden_states = hidden_states.reshape(
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+ )
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
+ def _set_ar_attention(self, causal_block_size: int):
+ self.register_to_config(num_frame_per_block=causal_block_size)
diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py
index baa0ede418..8a18ea5f3e 100644
--- a/src/diffusers/models/transformers/transformer_wan.py
+++ b/src/diffusers/models/transformers/transformer_wan.py
@@ -21,9 +21,10 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
@@ -34,18 +35,51 @@ from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-class WanAttnProcessor2_0:
+def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+ # encoder_hidden_states is only passed for cross-attention
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ if attn.fused_projections:
+ if attn.cross_attention_dim_head is None:
+ # In self-attention layers, we can fuse the entire QKV projection into a single linear
+ query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+ else:
+ # In cross-attention layers, we can only fuse the KV projections into a single linear
+ query = attn.to_q(hidden_states)
+ key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+ else:
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+ return query, key, value
+
+
+def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+ if attn.fused_projections:
+ key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+ else:
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
+ return key_img, value_img
+
+
+class WanAttnProcessor:
+ _attention_backend = None
+
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
- raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+ raise ImportError(
+ "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+ )
def __call__(
self,
- attn: Attention,
+ attn: "WanAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
- rotary_emb: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
encoder_hidden_states_img = None
if attn.add_k_proj is not None:
@@ -53,53 +87,65 @@ class WanAttnProcessor2_0:
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
- if encoder_hidden_states is None:
- encoder_hidden_states = hidden_states
- query = attn.to_q(hidden_states)
- key = attn.to_k(encoder_hidden_states)
- value = attn.to_v(encoder_hidden_states)
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
- if attn.norm_q is not None:
- query = attn.norm_q(query)
- if attn.norm_k is not None:
- key = attn.norm_k(key)
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
- def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
- dtype = torch.float32 if hidden_states.device.type == "mps" else torch.float64
- x_rotated = torch.view_as_complex(hidden_states.to(dtype).unflatten(3, (-1, 2)))
- x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
- return x_out.type_as(hidden_states)
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
- query = apply_rotary_emb(query, rotary_emb)
- key = apply_rotary_emb(key, rotary_emb)
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
- key_img = attn.add_k_proj(encoder_hidden_states_img)
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
- value_img = attn.add_v_proj(encoder_hidden_states_img)
- key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
- hidden_states_img = F.scaled_dot_product_attention(
- query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
)
- hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+ hidden_states_img = hidden_states_img.flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)
- hidden_states = F.scaled_dot_product_attention(
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
)
- hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+ hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
if hidden_states_img is not None:
@@ -110,6 +156,119 @@ class WanAttnProcessor2_0:
return hidden_states
+class WanAttnProcessor2_0:
+ def __new__(cls, *args, **kwargs):
+ deprecation_message = (
+ "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+ "Please use WanAttnProcessor instead. "
+ )
+ deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+ return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+ _default_processor_cls = WanAttnProcessor
+ _available_processors = [WanAttnProcessor]
+
+ def __init__(
+ self,
+ dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ eps: float = 1e-5,
+ dropout: float = 0.0,
+ added_kv_proj_dim: Optional[int] = None,
+ cross_attention_dim_head: Optional[int] = None,
+ processor=None,
+ ):
+ super().__init__()
+
+ self.inner_dim = dim_head * heads
+ self.heads = heads
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.cross_attention_dim_head = cross_attention_dim_head
+ self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+ self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+ self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+ self.to_out = torch.nn.ModuleList(
+ [
+ torch.nn.Linear(self.inner_dim, dim, bias=True),
+ torch.nn.Dropout(dropout),
+ ]
+ )
+ self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+ self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+ self.add_k_proj = self.add_v_proj = None
+ if added_kv_proj_dim is not None:
+ self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+ self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+ self.set_processor(processor)
+
+ def fuse_projections(self):
+ if getattr(self, "fused_projections", False):
+ return
+
+ if self.cross_attention_dim_head is None:
+ concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+ self.to_qkv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+ else:
+ concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+ concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ if self.added_kv_proj_dim is not None:
+ concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+ concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+ out_features, in_features = concatenated_weights.shape
+ with torch.device("meta"):
+ self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+ self.to_added_kv.load_state_dict(
+ {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+ )
+
+ self.fused_projections = True
+
+ @torch.no_grad()
+ def unfuse_projections(self):
+ if not getattr(self, "fused_projections", False):
+ return
+
+ if hasattr(self, "to_qkv"):
+ delattr(self, "to_qkv")
+ if hasattr(self, "to_kv"):
+ delattr(self, "to_kv")
+ if hasattr(self, "to_added_kv"):
+ delattr(self, "to_added_kv")
+
+ self.fused_projections = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
class WanImageEmbedding(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
super().__init__()
@@ -161,8 +320,11 @@ class WanTimeTextImageEmbedding(nn.Module):
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
+ timestep_seq_len: Optional[int] = None,
):
timestep = self.timesteps_proj(timestep)
+ if timestep_seq_len is not None:
+ timestep = timestep.unflatten(0, (1, timestep_seq_len))
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
@@ -179,7 +341,11 @@ class WanTimeTextImageEmbedding(nn.Module):
class WanRotaryPosEmbed(nn.Module):
def __init__(
- self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+ self,
+ attention_head_dim: int,
+ patch_size: Tuple[int, int, int],
+ max_seq_len: int,
+ theta: float = 10000.0,
):
super().__init__()
@@ -189,38 +355,55 @@ class WanRotaryPosEmbed(nn.Module):
h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim
-
- freqs = []
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+
+ freqs_cos = []
+ freqs_sin = []
+
for dim in [t_dim, h_dim, w_dim]:
- freq = get_1d_rotary_pos_embed(
- dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=freqs_dtype
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
+ dim,
+ max_seq_len,
+ theta,
+ use_real=True,
+ repeat_interleave_real=True,
+ freqs_dtype=freqs_dtype,
)
- freqs.append(freq)
- self.freqs = torch.cat(freqs, dim=1)
+ freqs_cos.append(freq_cos)
+ freqs_sin.append(freq_sin)
+
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
- freqs = self.freqs.to(hidden_states.device)
- freqs = freqs.split_with_sizes(
- [
- self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
- self.attention_head_dim // 6,
- self.attention_head_dim // 6,
- ],
- dim=1,
- )
+ split_sizes = [
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
+ self.attention_head_dim // 3,
+ self.attention_head_dim // 3,
+ ]
- freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
- freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
- freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
- return freqs
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
+
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+
+ return freqs_cos, freqs_sin
+@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
def __init__(
self,
@@ -236,33 +419,24 @@ class WanTransformerBlock(nn.Module):
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
- self.attn1 = Attention(
- query_dim=dim,
+ self.attn1 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
- processor=WanAttnProcessor2_0(),
+ cross_attention_dim_head=None,
+ processor=WanAttnProcessor(),
)
# 2. Cross-attention
- self.attn2 = Attention(
- query_dim=dim,
+ self.attn2 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
- added_proj_bias=True,
- processor=WanAttnProcessor2_0(),
+ cross_attention_dim_head=dim // num_heads,
+ processor=WanAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -279,18 +453,32 @@ class WanTransformerBlock(nn.Module):
temb: torch.Tensor,
rotary_emb: torch.Tensor,
) -> torch.Tensor:
- shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
- self.scale_shift_table + temb.float()
- ).chunk(6, dim=1)
+ if temb.ndim == 4:
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table.unsqueeze(0) + temb.float()
+ ).chunk(6, dim=2)
+ # batch_size, seq_len, 1, inner_dim
+ shift_msa = shift_msa.squeeze(2)
+ scale_msa = scale_msa.squeeze(2)
+ gate_msa = gate_msa.squeeze(2)
+ c_shift_msa = c_shift_msa.squeeze(2)
+ c_scale_msa = c_scale_msa.squeeze(2)
+ c_gate_msa = c_gate_msa.squeeze(2)
+ else:
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+ self.scale_shift_table + temb.float()
+ ).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
- attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
- attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
@@ -303,7 +491,9 @@ class WanTransformerBlock(nn.Module):
return hidden_states
-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
r"""
A Transformer model for video-like data used in the Wan model.
@@ -345,6 +535,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
_no_split_modules = ["WanTransformerBlock"]
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+ _repeated_blocks = ["WanTransformerBlock"]
@register_to_config
def __init__(
@@ -438,10 +629,22 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states = self.patch_embedding(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
+ if timestep.ndim == 2:
+ ts_seq_len = timestep.shape[1]
+ timestep = timestep.flatten() # batch_size * seq_len
+ else:
+ ts_seq_len = None
+
temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
- timestep, encoder_hidden_states, encoder_hidden_states_image
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
)
- timestep_proj = timestep_proj.unflatten(1, (6, -1))
+ if ts_seq_len is not None:
+ # batch_size, seq_len, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
+ else:
+ # batch_size, 6, inner_dim
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
if encoder_hidden_states_image is not None:
encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
@@ -457,7 +660,14 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
# 5. Output norm, projection & unpatchify
- shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+ if temb.ndim == 3:
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+ shift = shift.squeeze(2)
+ scale = scale.squeeze(2)
+ else:
+ # batch_size, inner_dim
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
# Move the shift and scale tensors to the same device as hidden_states.
# When using multi-GPU inference via accelerate these will be on the
diff --git a/src/diffusers/models/transformers/transformer_wan_vace.py b/src/diffusers/models/transformers/transformer_wan_vace.py
index 1a6f2af59a..e039d36219 100644
--- a/src/diffusers/models/transformers/transformer_wan_vace.py
+++ b/src/diffusers/models/transformers/transformer_wan_vace.py
@@ -22,12 +22,17 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
-from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
-from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+from .transformer_wan import (
+ WanAttention,
+ WanAttnProcessor,
+ WanRotaryPosEmbed,
+ WanTimeTextImageEmbedding,
+ WanTransformerBlock,
+)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -55,33 +60,22 @@ class WanVACETransformerBlock(nn.Module):
# 2. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
- self.attn1 = Attention(
- query_dim=dim,
+ self.attn1 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
- processor=WanAttnProcessor2_0(),
+ processor=WanAttnProcessor(),
)
# 3. Cross-attention
- self.attn2 = Attention(
- query_dim=dim,
+ self.attn2 = WanAttention(
+ dim=dim,
heads=num_heads,
- kv_heads=num_heads,
dim_head=dim // num_heads,
- qk_norm=qk_norm,
eps=eps,
- bias=True,
- cross_attention_dim=None,
- out_bias=True,
added_kv_proj_dim=added_kv_proj_dim,
- added_proj_bias=True,
- processor=WanAttnProcessor2_0(),
+ processor=WanAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -116,12 +110,12 @@ class WanVACETransformerBlock(nn.Module):
norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
control_hidden_states
)
- attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+ attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
# 2. Cross-attention
norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
- attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+ attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
control_hidden_states = control_hidden_states + attn_output
# 3. Feed-forward
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 0cf5133c54..736deb28c3 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -165,8 +165,9 @@ class UNet2DConditionModel(
"""
_supports_gradient_checkpointing = True
- _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
+ _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"]
_skip_layerwise_casting_patterns = ["norm"]
+ _repeated_blocks = ["BasicTransformerBlock"]
@register_to_config
def __init__(
diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py
new file mode 100644
index 0000000000..e0f2e31388
--- /dev/null
+++ b/src/diffusers/modular_pipelines/__init__.py
@@ -0,0 +1,88 @@
+from typing import TYPE_CHECKING
+
+from ..utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+# These modules contain pipelines from multiple libraries/frameworks
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ..utils import dummy_pt_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
+else:
+ _import_structure["modular_pipeline"] = [
+ "ModularPipelineBlocks",
+ "ModularPipeline",
+ "PipelineBlock",
+ "AutoPipelineBlocks",
+ "SequentialPipelineBlocks",
+ "LoopSequentialPipelineBlocks",
+ "PipelineState",
+ "BlockState",
+ ]
+ _import_structure["modular_pipeline_utils"] = [
+ "ComponentSpec",
+ "ConfigSpec",
+ "InputParam",
+ "OutputParam",
+ "InsertableDict",
+ ]
+ _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
+ _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
+ _import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
+ _import_structure["components_manager"] = ["ComponentsManager"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not is_torch_available():
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ..utils.dummy_pt_objects import * # noqa F403
+ else:
+ from .components_manager import ComponentsManager
+ from .flux import FluxAutoBlocks, FluxModularPipeline
+ from .modular_pipeline import (
+ AutoPipelineBlocks,
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ ModularPipeline,
+ ModularPipelineBlocks,
+ PipelineBlock,
+ PipelineState,
+ SequentialPipelineBlocks,
+ )
+ from .modular_pipeline_utils import (
+ ComponentSpec,
+ ConfigSpec,
+ InputParam,
+ InsertableDict,
+ OutputParam,
+ )
+ from .stable_diffusion_xl import (
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLModularPipeline,
+ )
+ from .wan import WanAutoBlocks, WanModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
new file mode 100644
index 0000000000..f48a227e2e
--- /dev/null
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -0,0 +1,1068 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import time
+from collections import OrderedDict
+from itertools import combinations
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+
+from ..hooks import ModelHook
+from ..utils import (
+ is_accelerate_available,
+ logging,
+)
+
+
+if is_accelerate_available():
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+ from accelerate.state import PartialState
+ from accelerate.utils import send_to_device
+ from accelerate.utils.memory import clear_device_cache
+ from accelerate.utils.modeling import convert_file_size_to_int
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class CustomOffloadHook(ModelHook):
+ """
+ A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
+ on the given device. Optionally offloads other models to the CPU before the forward pass is called.
+
+ Args:
+ execution_device(`str`, `int` or `torch.device`, *optional*):
+ The device on which the model should be executed. Will default to the MPS device if it's available, then
+ GPU 0 if there is a GPU, and finally to the CPU.
+ """
+
+ no_grad = False
+
+ def __init__(
+ self,
+ execution_device: Optional[Union[str, int, torch.device]] = None,
+ other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
+ offload_strategy: Optional["AutoOffloadStrategy"] = None,
+ ):
+ self.execution_device = execution_device if execution_device is not None else PartialState().default_device
+ self.other_hooks = other_hooks
+ self.offload_strategy = offload_strategy
+ self.model_id = None
+
+ def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
+ self.offload_strategy = offload_strategy
+
+ def add_other_hook(self, hook: "UserCustomOffloadHook"):
+ """
+ Add a hook to the list of hooks to consider for offloading.
+ """
+ if self.other_hooks is None:
+ self.other_hooks = []
+ self.other_hooks.append(hook)
+
+ def init_hook(self, module):
+ return module.to("cpu")
+
+ def pre_forward(self, module, *args, **kwargs):
+ if module.device != self.execution_device:
+ if self.other_hooks is not None:
+ hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
+ # offload all other hooks
+ start_time = time.perf_counter()
+ if self.offload_strategy is not None:
+ hooks_to_offload = self.offload_strategy(
+ hooks=hooks_to_offload,
+ model_id=self.model_id,
+ model=module,
+ execution_device=self.execution_device,
+ )
+ end_time = time.perf_counter()
+ logger.info(
+ f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
+ )
+
+ for hook in hooks_to_offload:
+ logger.info(
+ f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
+ )
+ hook.offload()
+
+ if hooks_to_offload:
+ clear_device_cache()
+ module.to(self.execution_device)
+ return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
+
+
+class UserCustomOffloadHook:
+ """
+ A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
+ the hook or remove it entirely.
+ """
+
+ def __init__(self, model_id, model, hook):
+ self.model_id = model_id
+ self.model = model
+ self.hook = hook
+
+ def offload(self):
+ self.hook.init_hook(self.model)
+
+ def attach(self):
+ add_hook_to_module(self.model, self.hook)
+ self.hook.model_id = self.model_id
+
+ def remove(self):
+ remove_hook_from_module(self.model)
+ self.hook.model_id = None
+
+ def add_other_hook(self, hook: "UserCustomOffloadHook"):
+ self.hook.add_other_hook(hook)
+
+
+def custom_offload_with_hook(
+ model_id: str,
+ model: torch.nn.Module,
+ execution_device: Union[str, int, torch.device] = None,
+ offload_strategy: Optional["AutoOffloadStrategy"] = None,
+):
+ hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
+ user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
+ user_hook.attach()
+ return user_hook
+
+
+# this is the class that user can customize to implement their own offload strategy
+class AutoOffloadStrategy:
+ """
+ Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
+ the available memory on the device.
+ """
+
+ # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device
+ # the actual memory usage would be higher. But it's simpler this way, and can be tested
+ def __init__(self, memory_reserve_margin="3GB"):
+ self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)
+
+ def __call__(self, hooks, model_id, model, execution_device):
+ if len(hooks) == 0:
+ return []
+
+ current_module_size = model.get_memory_footprint()
+
+ mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
+ mem_on_device = mem_on_device - self.memory_reserve_margin
+ if current_module_size < mem_on_device:
+ return []
+
+ min_memory_offload = current_module_size - mem_on_device
+ logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")
+
+ # exlucde models that's not currently loaded on the device
+ module_sizes = dict(
+ sorted(
+ {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(),
+ key=lambda x: x[1],
+ reverse=True,
+ )
+ )
+
+ # YiYi/Dhruv TODO: sort smallest to largest, and offload in that order we would tend to keep the larger models on GPU more often
+ def search_best_candidate(module_sizes, min_memory_offload):
+ """
+ search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
+ minimum memory offload size. the combination of models should add up to the smallest modulesize that is
+ larger than `min_memory_offload`
+ """
+ model_ids = list(module_sizes.keys())
+ best_candidate = None
+ best_size = float("inf")
+ for r in range(1, len(model_ids) + 1):
+ for candidate_model_ids in combinations(model_ids, r):
+ candidate_size = sum(
+ module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
+ )
+ if candidate_size < min_memory_offload:
+ continue
+ else:
+ if best_candidate is None or candidate_size < best_size:
+ best_candidate = candidate_model_ids
+ best_size = candidate_size
+
+ return best_candidate
+
+ best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)
+
+ if best_offload_model_ids is None:
+ # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
+ logger.warning("no combination of models to offload to cpu is found, offloading all models")
+ hooks_to_offload = hooks
+ else:
+ hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]
+
+ return hooks_to_offload
+
+
+# utils for display component info in a readable format
+# TODO: move to a different file
+def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
+ """Summarizes a dictionary by finding common prefixes that share the same value.
+
+ For a dictionary with dot-separated keys like: {
+ 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
+ 'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
+ 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
+ }
+
+ Returns a dictionary where keys are the shortest common prefixes and values are their shared values: {
+ 'down_blocks': [0.6], 'up_blocks': [0.3]
+ }
+ """
+ # First group by values - convert lists to tuples to make them hashable
+ value_to_keys = {}
+ for key, value in d.items():
+ value_tuple = tuple(value) if isinstance(value, list) else value
+ if value_tuple not in value_to_keys:
+ value_to_keys[value_tuple] = []
+ value_to_keys[value_tuple].append(key)
+
+ def find_common_prefix(keys: List[str]) -> str:
+ """Find the shortest common prefix among a list of dot-separated keys."""
+ if not keys:
+ return ""
+ if len(keys) == 1:
+ return keys[0]
+
+ # Split all keys into parts
+ key_parts = [k.split(".") for k in keys]
+
+ # Find how many initial parts are common
+ common_length = 0
+ for parts in zip(*key_parts):
+ if len(set(parts)) == 1: # All parts at this position are the same
+ common_length += 1
+ else:
+ break
+
+ if common_length == 0:
+ return ""
+
+ # Return the common prefix
+ return ".".join(key_parts[0][:common_length])
+
+ # Create summary by finding common prefixes for each value group
+ summary = {}
+ for value_tuple, keys in value_to_keys.items():
+ prefix = find_common_prefix(keys)
+ if prefix: # Only add if we found a common prefix
+ # Convert tuple back to list if it was originally a list
+ value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
+ summary[prefix] = value
+ else:
+ summary[""] = value # Use empty string if no common prefix
+
+ return summary
+
+
+class ComponentsManager:
+ """
+ A central registry and management system for model components across multiple pipelines.
+
+ [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text
+ encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
+ management, and component organization.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Example:
+ ```python
+ from diffusers import ComponentsManager
+
+ # Create a components manager
+ cm = ComponentsManager()
+
+ # Add components
+ cm.add("unet", unet_model, collection="sdxl")
+ cm.add("vae", vae_model, collection="sdxl")
+
+ # Enable auto offloading
+ cm.enable_auto_cpu_offload(device="cuda")
+
+ # Retrieve components
+ unet = cm.get_one(name="unet", collection="sdxl")
+ ```
+ """
+
+ _available_info_fields = [
+ "model_id",
+ "added_time",
+ "collection",
+ "class_name",
+ "size_gb",
+ "adapters",
+ "has_hook",
+ "execution_device",
+ "ip_adapter",
+ ]
+
+ def __init__(self):
+ self.components = OrderedDict()
+ # YiYi TODO: can remove once confirm we don't need this in mellon
+ self.added_time = OrderedDict() # Store when components were added
+ self.collections = OrderedDict() # collection_name -> set of component_names
+ self.model_hooks = None
+ self._auto_offload_enabled = False
+
+ def _lookup_ids(
+ self,
+ name: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ components: Optional[OrderedDict] = None,
+ ):
+ """
+ Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of
+ component_ids
+ """
+ if components is None:
+ components = self.components
+
+ if name:
+ ids_by_name = set()
+ for component_id, component in components.items():
+ comp_name = self._id_to_name(component_id)
+ if comp_name == name:
+ ids_by_name.add(component_id)
+ else:
+ ids_by_name = set(components.keys())
+ if collection:
+ ids_by_collection = set()
+ for component_id, component in components.items():
+ if component_id in self.collections[collection]:
+ ids_by_collection.add(component_id)
+ else:
+ ids_by_collection = set(components.keys())
+ if load_id:
+ ids_by_load_id = set()
+ for name, component in components.items():
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
+ ids_by_load_id.add(name)
+ else:
+ ids_by_load_id = set(components.keys())
+
+ ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
+ return ids
+
+ @staticmethod
+ def _id_to_name(component_id: str):
+ return "_".join(component_id.split("_")[:-1])
+
+ def add(self, name: str, component: Any, collection: Optional[str] = None):
+ """
+ Add a component to the ComponentsManager.
+
+ Args:
+ name (str): The name of the component
+ component (Any): The component to add
+ collection (Optional[str]): The collection to add the component to
+
+ Returns:
+ str: The unique component ID, which is generated as "{name}_{id(component)}" where
+ id(component) is Python's built-in unique identifier for the object
+ """
+ component_id = f"{name}_{id(component)}"
+ is_new_component = True
+
+ # check for duplicated components
+ for comp_id, comp in self.components.items():
+ if comp == component:
+ comp_name = self._id_to_name(comp_id)
+ if comp_name == name:
+ logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
+ component_id = comp_id
+ is_new_component = False
+ break
+ else:
+ logger.warning(
+ f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
+ f"To remove a duplicate, call `components_manager.remove('')`."
+ )
+
+ # check for duplicated load_id and warn (we do not delete for you)
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
+ components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
+
+ if components_with_same_load_id:
+ existing = ", ".join(components_with_same_load_id)
+ logger.warning(
+ f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
+ f"To remove a duplicate, call `components_manager.remove('')`."
+ )
+
+ # add component to components manager
+ self.components[component_id] = component
+ self.added_time[component_id] = time.time()
+
+ if collection:
+ if collection not in self.collections:
+ self.collections[collection] = set()
+ if component_id not in self.collections[collection]:
+ comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
+ for comp_id in comp_ids_in_collection:
+ logger.warning(
+ f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
+ )
+ # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
+ self.remove_from_collection(comp_id, collection)
+
+ self.collections[collection].add(component_id)
+ logger.info(
+ f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
+ )
+ else:
+ logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")
+
+ if self._auto_offload_enabled and is_new_component:
+ self.enable_auto_cpu_offload(self._auto_offload_device)
+
+ return component_id
+
+ def remove_from_collection(self, component_id: str, collection: str):
+ """
+ Remove a component from a collection.
+ """
+ if collection not in self.collections:
+ logger.warning(f"Collection '{collection}' not found in ComponentsManager")
+ return
+ if component_id not in self.collections[collection]:
+ logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
+ return
+ # remove from the collection
+ self.collections[collection].remove(component_id)
+ # check if this component is in any other collection
+ comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
+ if not comp_colls: # only if no other collection contains this component, remove it
+ logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
+ self.remove(component_id)
+
+ def remove(self, component_id: str = None):
+ """
+ Remove a component from the ComponentsManager.
+
+ Args:
+ component_id (str): The ID of the component to remove
+ """
+ if component_id not in self.components:
+ logger.warning(f"Component '{component_id}' not found in ComponentsManager")
+ return
+
+ component = self.components.pop(component_id)
+ self.added_time.pop(component_id)
+
+ for collection in self.collections:
+ if component_id in self.collections[collection]:
+ self.collections[collection].remove(component_id)
+
+ if self._auto_offload_enabled:
+ self.enable_auto_cpu_offload(self._auto_offload_device)
+ else:
+ if isinstance(component, torch.nn.Module):
+ component.to("cpu")
+ del component
+ import gc
+
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ # YiYi TODO: rename to search_components for now, may remove this method
+ def search_components(
+ self,
+ names: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ return_dict_with_names: bool = True,
+ ):
+ """
+ Search components by name with simple pattern matching. Optionally filter by collection or load_id.
+
+ Args:
+ names: Component name(s) or pattern(s)
+ Patterns:
+ - "unet" : match any component with base name "unet" (e.g., unet_123abc)
+ - "!unet" : everything except components with base name "unet"
+ - "unet*" : anything with base name starting with "unet"
+ - "!unet*" : anything with base name NOT starting with "unet"
+ - "*unet*" : anything with base name containing "unet"
+ - "!*unet*" : anything with base name NOT containing "unet"
+ - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
+ - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
+ - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
+ collection: Optional collection to filter by
+ load_id: Optional load_id to filter by
+ return_dict_with_names:
+ If True, returns a dictionary with component names as keys, throw an error if
+ multiple components with the same name are found If False, returns a dictionary
+ with component IDs as keys
+
+ Returns:
+ Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
+ component IDs to components if return_dict_with_names=False
+ """
+
+ # select components based on collection and load_id filters
+ selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
+ components = {k: self.components[k] for k in selected_ids}
+
+ def get_return_dict(components, return_dict_with_names):
+ """
+ Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
+ mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
+ names are found when return_dict_with_names=True
+ """
+ if return_dict_with_names:
+ dict_to_return = {}
+ for comp_id, comp in components.items():
+ comp_name = self._id_to_name(comp_id)
+ if comp_name in dict_to_return:
+ raise ValueError(
+ f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
+ )
+ dict_to_return[comp_name] = comp
+ return dict_to_return
+ else:
+ return components
+
+ # if no names are provided, return the filtered components as it is
+ if names is None:
+ return get_return_dict(components, return_dict_with_names)
+
+ # if names is not a string, raise an error
+ elif not isinstance(names, str):
+ raise ValueError(f"Invalid type for `names: {type(names)}, only support string")
+
+ # Create mapping from component_id to base_name for components to be used for pattern matching
+ base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}
+
+ # Helper function to check if a component matches a pattern based on its base name
+ def matches_pattern(component_id, pattern, exact_match=False):
+ """
+ Helper function to check if a component matches a pattern based on its base name.
+
+ Args:
+ component_id: The component ID to check
+ pattern: The pattern to match against
+ exact_match: If True, only exact matches to base_name are considered
+ """
+ base_name = base_names[component_id]
+
+ # Exact match with base name
+ if exact_match:
+ return pattern == base_name
+
+ # Prefix match (ends with *)
+ elif pattern.endswith("*"):
+ prefix = pattern[:-1]
+ return base_name.startswith(prefix)
+
+ # Contains match (starts with *)
+ elif pattern.startswith("*"):
+ search = pattern[1:-1] if pattern.endswith("*") else pattern[1:]
+ return search in base_name
+
+ # Exact match (no wildcards)
+ else:
+ return pattern == base_name
+
+ # Check if this is a "not" pattern
+ is_not_pattern = names.startswith("!")
+ if is_not_pattern:
+ names = names[1:] # Remove the ! prefix
+
+ # Handle OR patterns (containing |)
+ if "|" in names:
+ terms = names.split("|")
+ matches = {}
+
+ for comp_id, comp in components.items():
+ # For OR patterns with exact names (no wildcards), we do exact matching on base names
+ exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)
+
+ # Check if any of the terms match this component
+ should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
+
+ # Flip the decision if this is a NOT pattern
+ if is_not_pattern:
+ should_include = not should_include
+
+ if should_include:
+ matches[comp_id] = comp
+
+ log_msg = "NOT " if is_not_pattern else ""
+ match_type = "exactly matching" if exact_match else "matching any of patterns"
+ logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
+
+ # Try exact match with a base name
+ elif any(names == base_name for base_name in base_names.values()):
+ # Find all components with this base name
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (base_names[comp_id] == names) != is_not_pattern
+ }
+
+ if is_not_pattern:
+ logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
+
+ # Prefix match (ends with *)
+ elif names.endswith("*"):
+ prefix = names[:-1]
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if base_names[comp_id].startswith(prefix) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")
+
+ # Contains match (starts with *)
+ elif names.startswith("*"):
+ search = names[1:-1] if names.endswith("*") else names[1:]
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (search in base_names[comp_id]) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
+
+ # Substring match (no wildcards, but not an exact component name)
+ elif any(names in base_name for base_name in base_names.values()):
+ matches = {
+ comp_id: comp
+ for comp_id, comp in components.items()
+ if (names in base_names[comp_id]) != is_not_pattern
+ }
+ if is_not_pattern:
+ logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
+ else:
+ logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
+
+ else:
+ raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
+
+ if not matches:
+ raise ValueError(f"No components found matching pattern '{names}'")
+
+ return get_return_dict(matches, return_dict_with_names)
+
+ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
+ """
+ Enable automatic CPU offloading for all components.
+
+ The algorithm works as follows:
+ 1. All models start on CPU by default
+ 2. When a model's forward pass is called, it's moved to the execution device
+ 3. If there's insufficient memory, other models on the device are moved back to CPU
+ 4. The system tries to offload the smallest combination of models that frees enough memory
+ 5. Models stay on the execution device until another model needs memory and forces them off
+
+ Args:
+ device (Union[str, int, torch.device]): The execution device where models are moved for forward passes
+ memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
+ memory to keep free on the device to avoid running out of memory during model
+ execution (e.g., for intermediate activations, gradients, etc.)
+ """
+ if not is_accelerate_available():
+ raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
+
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
+ remove_hook_from_module(component, recurse=True)
+
+ self.disable_auto_cpu_offload()
+ offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
+ device = torch.device(device)
+ if device.index is None:
+ device = torch.device(f"{device.type}:{0}")
+ all_hooks = []
+ for name, component in self.components.items():
+ if isinstance(component, torch.nn.Module):
+ hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
+ all_hooks.append(hook)
+
+ for hook in all_hooks:
+ other_hooks = [h for h in all_hooks if h is not hook]
+ for other_hook in other_hooks:
+ if other_hook.hook.execution_device == hook.hook.execution_device:
+ hook.add_other_hook(other_hook)
+
+ self.model_hooks = all_hooks
+ self._auto_offload_enabled = True
+ self._auto_offload_device = device
+
+ def disable_auto_cpu_offload(self):
+ """
+ Disable automatic CPU offloading for all components.
+ """
+ if self.model_hooks is None:
+ self._auto_offload_enabled = False
+ return
+
+ for hook in self.model_hooks:
+ hook.offload()
+ hook.remove()
+ if self.model_hooks:
+ clear_device_cache()
+ self.model_hooks = None
+ self._auto_offload_enabled = False
+
+ # YiYi TODO: (1) add quantization info
+ def get_model_info(
+ self,
+ component_id: str,
+ fields: Optional[Union[str, List[str]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ """Get comprehensive information about a component.
+
+ Args:
+ component_id (str): Name of the component to get info for
+ fields (Optional[Union[str, List[str]]]):
+ Field(s) to return. Can be a string for single field or list of fields. If None, uses the
+ available_info_fields setting.
+
+ Returns:
+ Dictionary containing requested component metadata. If fields is specified, returns only those fields.
+ Otherwise, returns all fields.
+ """
+ if component_id not in self.components:
+ raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
+
+ component = self.components[component_id]
+
+ # Validate fields if specified
+ if fields is not None:
+ if isinstance(fields, str):
+ fields = [fields]
+ for field in fields:
+ if field not in self._available_info_fields:
+ raise ValueError(f"Field '{field}' not found in available_info_fields")
+
+ # Build complete info dict first
+ info = {
+ "model_id": component_id,
+ "added_time": self.added_time[component_id],
+ "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps])
+ or None,
+ }
+
+ # Additional info for torch.nn.Module components
+ if isinstance(component, torch.nn.Module):
+ # Check for hook information
+ has_hook = hasattr(component, "_hf_hook")
+ execution_device = None
+ if has_hook and hasattr(component._hf_hook, "execution_device"):
+ execution_device = component._hf_hook.execution_device
+
+ info.update(
+ {
+ "class_name": component.__class__.__name__,
+ "size_gb": component.get_memory_footprint() / (1024**3),
+ "adapters": None, # Default to None
+ "has_hook": has_hook,
+ "execution_device": execution_device,
+ }
+ )
+
+ # Get adapters if applicable
+ if hasattr(component, "peft_config"):
+ info["adapters"] = list(component.peft_config.keys())
+
+ # Check for IP-Adapter scales
+ if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
+ processors = copy.deepcopy(component.attn_processors)
+ # First check if any processor is an IP-Adapter
+ processor_types = [v.__class__.__name__ for v in processors.values()]
+ if any("IPAdapter" in ptype for ptype in processor_types):
+ # Then get scales only from IP-Adapter processors
+ scales = {
+ k: v.scale
+ for k, v in processors.items()
+ if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
+ }
+ if scales:
+ info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
+
+ # If fields specified, filter info
+ if fields is not None:
+ return {k: v for k, v in info.items() if k in fields}
+ else:
+ return info
+
+ # YiYi TODO: (1) add display fields, allow user to set which fields to display in the comnponents table
+ def __repr__(self):
+ # Handle empty components case
+ if not self.components:
+ return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50
+
+ # Extract load_id if available
+ def get_load_id(component):
+ if hasattr(component, "_diffusers_load_id"):
+ return component._diffusers_load_id
+ return "N/A"
+
+ # Format device info compactly
+ def format_device(component, info):
+ if not info["has_hook"]:
+ return str(getattr(component, "device", "N/A"))
+ else:
+ device = str(getattr(component, "device", "N/A"))
+ exec_device = str(info["execution_device"] or "N/A")
+ return f"{device}({exec_device})"
+
+ # Get max length of load_ids for models
+ load_ids = [
+ get_load_id(component)
+ for component in self.components.values()
+ if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
+ ]
+ max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
+
+ # Get all collections for each component
+ component_collections = {}
+ for name in self.components.keys():
+ component_collections[name] = []
+ for coll, comps in self.collections.items():
+ if name in comps:
+ component_collections[name].append(coll)
+ if not component_collections[name]:
+ component_collections[name] = ["N/A"]
+
+ # Find the maximum collection name length
+ all_collections = [coll for colls in component_collections.values() for coll in colls]
+ max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
+
+ col_widths = {
+ "id": max(15, max(len(name) for name in self.components.keys())),
+ "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
+ "device": 20,
+ "dtype": 15,
+ "size": 10,
+ "load_id": max_load_id_len,
+ "collection": max_collection_len,
+ }
+
+ # Create the header lines
+ sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
+ dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
+
+ output = "Components:\n" + sep_line
+
+ # Separate components into models and others
+ models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
+ others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}
+
+ # Models section
+ if models:
+ output += "Models:\n" + dash_line
+ # Column headers
+ output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
+ output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
+ output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
+ output += dash_line
+
+ # Model entries
+ for name, component in models.items():
+ info = self.get_model_info(name)
+ device_str = format_device(component, info)
+ dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
+ load_id = get_load_id(component)
+
+ # Print first collection on the main line
+ first_collection = component_collections[name][0] if component_collections[name] else "N/A"
+
+ output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
+ output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
+ output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
+
+ # Print additional collections on separate lines if they exist
+ for i in range(1, len(component_collections[name])):
+ collection = component_collections[name][i]
+ output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | "
+ output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
+ output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
+
+ output += dash_line
+
+ # Other components section
+ if others:
+ if models: # Add extra newline if we had models section
+ output += "\n"
+ output += "Other Components:\n" + dash_line
+ # Column headers for other components
+ output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n"
+ output += dash_line
+
+ # Other component entries
+ for name, component in others.items():
+ info = self.get_model_info(name)
+
+ # Print first collection on the main line
+ first_collection = component_collections[name][0] if component_collections[name] else "N/A"
+
+ output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
+
+ # Print additional collections on separate lines if they exist
+ for i in range(1, len(component_collections[name])):
+ collection = component_collections[name][i]
+ output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n"
+
+ output += dash_line
+
+ # Add additional component info
+ output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
+ for name in self.components:
+ info = self.get_model_info(name)
+ if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
+ output += f"\n{name}:\n"
+ if info.get("adapters") is not None:
+ output += f" Adapters: {info['adapters']}\n"
+ if info.get("ip_adapter"):
+ output += " IP-Adapter: Enabled\n"
+
+ return output
+
+ def get_one(
+ self,
+ component_id: Optional[str] = None,
+ name: Optional[str] = None,
+ collection: Optional[str] = None,
+ load_id: Optional[str] = None,
+ ) -> Any:
+ """
+ Get a single component by either:
+ - searching name (pattern matching), collection, or load_id.
+ - passing in a component_id
+ Raises an error if multiple components match or none are found.
+
+ Args:
+ component_id (Optional[str]): Optional component ID to get
+ name (Optional[str]): Component name or pattern
+ collection (Optional[str]): Optional collection to filter by
+ load_id (Optional[str]): Optional load_id to filter by
+
+ Returns:
+ A single component
+
+ Raises:
+ ValueError: If no components match or multiple components match
+ """
+
+ if component_id is not None and (name is not None or collection is not None or load_id is not None):
+ raise ValueError("If searching by component_id, do not pass name, collection, or load_id")
+
+ # search by component_id
+ if component_id is not None:
+ if component_id not in self.components:
+ raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
+ return self.components[component_id]
+ # search with name/collection/load_id
+ results = self.search_components(name, collection, load_id)
+
+ if not results:
+ raise ValueError(f"No components found matching '{name}'")
+
+ if len(results) > 1:
+ raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
+
+ return next(iter(results.values()))
+
+ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str] = None):
+ """
+ Get component IDs by a list of names, optionally filtered by collection.
+
+ Args:
+ names (Union[str, List[str]]): List of component names
+ collection (Optional[str]): Optional collection to filter by
+
+ Returns:
+ List[str]: List of component IDs
+ """
+ ids = set()
+ if not isinstance(names, list):
+ names = [names]
+ for name in names:
+ ids.update(self._lookup_ids(name=name, collection=collection))
+ return list(ids)
+
+ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional[bool] = True):
+ """
+ Get components by a list of IDs.
+
+ Args:
+ ids (List[str]):
+ List of component IDs
+ return_dict_with_names (Optional[bool]):
+ Whether to return a dictionary with component names as keys:
+
+ Returns:
+ Dict[str, Any]: Dictionary of components.
+ - If return_dict_with_names=True, keys are component names.
+ - If return_dict_with_names=False, keys are component IDs.
+
+ Raises:
+ ValueError: If duplicate component names are found in the search results when return_dict_with_names=True
+ """
+ components = {id: self.components[id] for id in ids}
+
+ if return_dict_with_names:
+ dict_to_return = {}
+ for comp_id, comp in components.items():
+ comp_name = self._id_to_name(comp_id)
+ if comp_name in dict_to_return:
+ raise ValueError(
+ f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
+ )
+ dict_to_return[comp_name] = comp
+ return dict_to_return
+ else:
+ return components
+
+ def get_components_by_names(self, names: List[str], collection: Optional[str] = None):
+ """
+ Get components by a list of names, optionally filtered by collection.
+
+ Args:
+ names (List[str]): List of component names
+ collection (Optional[str]): Optional collection to filter by
+
+ Returns:
+ Dict[str, Any]: Dictionary of components with component names as keys
+
+ Raises:
+ ValueError: If duplicate component names are found in the search results
+ """
+ ids = self.get_ids(names, collection)
+ return self.get_components_by_ids(ids)
diff --git a/src/diffusers/modular_pipelines/flux/__init__.py b/src/diffusers/modular_pipelines/flux/__init__.py
new file mode 100644
index 0000000000..2891edf790
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/__init__.py
@@ -0,0 +1,66 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["FluxTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "FluxAutoBeforeDenoiseStep",
+ "FluxAutoBlocks",
+ "FluxAutoBlocks",
+ "FluxAutoDecodeStep",
+ "FluxAutoDenoiseStep",
+ ]
+ _import_structure["modular_pipeline"] = ["FluxModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import FluxTextEncoderStep
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ FluxAutoBeforeDenoiseStep,
+ FluxAutoBlocks,
+ FluxAutoDecodeStep,
+ FluxAutoDenoiseStep,
+ )
+ from .modular_pipeline import FluxModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py
new file mode 100644
index 0000000000..ffc77bb24f
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -0,0 +1,420 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+
+def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+
+class FluxInputStep(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Input processing step that:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+ "have a final batch_size of batch_size * num_images_per_prompt."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ # TODO: support negative embeddings?
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]:
+ raise ValueError(
+ "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`"
+ f" {block_state.pooled_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # TODO: consider adding negative embeddings?
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxSetTimestepsStep(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("guidance_scale", default=3.5),
+ InputParam("latents", type_hint=torch.Tensor),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+ scheduler = components.scheduler
+
+ latents = block_state.latents
+ image_seq_len = latents.shape[1]
+
+ num_inference_steps = block_state.num_inference_steps
+ sigmas = block_state.sigmas
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(scheduler.config, "use_flow_sigmas") and scheduler.config.use_flow_sigmas:
+ sigmas = None
+ block_state.sigmas = sigmas
+ mu = calculate_shift(
+ image_seq_len,
+ scheduler.config.get("base_image_seq_len", 256),
+ scheduler.config.get("max_image_seq_len", 4096),
+ scheduler.config.get("base_shift", 0.5),
+ scheduler.config.get("max_shift", 1.15),
+ )
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ scheduler, block_state.num_inference_steps, block_state.device, sigmas=block_state.sigmas, mu=mu
+ )
+ if components.transformer.config.guidance_embeds:
+ guidance = torch.full([1], block_state.guidance_scale, device=block_state.device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+ block_state.guidance = guidance
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class FluxPrepareLatentsStep(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_images_per_prompt", type_hint=int, default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ ),
+ OutputParam(
+ "latent_image_ids",
+ type_hint=torch.Tensor,
+ description="IDs computed from the image sequence needed for RoPE",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % (components.vae_scale_factor * 2) != 0) or (
+ block_state.width is not None and block_state.width % (components.vae_scale_factor * 2) != 0
+ ):
+ logger.warning(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
+ )
+
+ @staticmethod
+ def prepare_latents(
+ comp,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ ):
+ # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over
+ # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were
+ # to go with the "# Copied from ..." approach. Or maybe there's a way?
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (comp.vae_scale_factor * 2))
+ width = 2 * (int(width) // (comp.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if latents is not None:
+ latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+ return latents.to(device=device, dtype=dtype), latent_image_ids
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ return latents, latent_image_ids
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+ block_state.device = components._execution_device
+ block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this?
+ block_state.num_channels_latents = components.num_channels_latents
+
+ self.check_inputs(components, block_state)
+
+ block_state.latents, block_state.latent_image_ids = self.prepare_latents(
+ components,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ block_state.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/decoders.py b/src/diffusers/modular_pipelines/flux/decoders.py
new file mode 100644
index 0000000000..8d561d38c6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/decoders.py
@@ -0,0 +1,114 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKL
+from ...utils import logging
+from ...video_processor import VaeImageProcessor
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+
+class FluxDecodeStep(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 16}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("output_type", default="pil"),
+ InputParam("height", default=1024),
+ InputParam("width", default=1024),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], torch.Tensor, np.ndarray],
+ description="The generated images, can be a list of PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ vae = components.vae
+
+ if not block_state.output_type == "latent":
+ latents = block_state.latents
+ latents = _unpack_latents(latents, block_state.height, block_state.width, components.vae_scale_factor)
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+ block_state.images = vae.decode(latents, return_dict=False)[0]
+ block_state.images = components.image_processor.postprocess(
+ block_state.images, output_type=block_state.output_type
+ )
+ else:
+ block_state.images = block_state.latents
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py
new file mode 100644
index 0000000000..c4619c17fb
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/denoise.py
@@ -0,0 +1,230 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import torch
+
+from ...models import FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ PipelineBlock,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class FluxLoopDenoiser(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("transformer", FluxTransformer2DModel)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [InputParam("joint_attention_kwargs")]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "guidance",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Guidance scale as a tensor",
+ ),
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Prompt embeddings",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pooled prompt embeddings",
+ ),
+ InputParam(
+ "text_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from text sequence needed for RoPE",
+ ),
+ InputParam(
+ "latent_image_ids",
+ required=True,
+ type_hint=torch.Tensor,
+ description="IDs computed from image sequence needed for RoPE",
+ ),
+ # TODO: guidance
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ noise_pred = components.transformer(
+ hidden_states=block_state.latents,
+ timestep=t.flatten() / 1000,
+ guidance=block_state.guidance,
+ encoder_hidden_states=block_state.prompt_embeds,
+ pooled_projections=block_state.pooled_prompt_embeds,
+ joint_attention_kwargs=block_state.joint_attention_kwargs,
+ txt_ids=block_state.text_ids,
+ img_ids=block_state.latent_image_ids,
+ return_dict=False,
+ )[0]
+ block_state.noise_pred = noise_pred
+
+ return components, block_state
+
+
+class FluxLoopAfterDenoiser(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `FluxDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return []
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [InputParam("generator")]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # Perform scheduler step using the predicted output
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
+ ComponentSpec("transformer", FluxTransformer2DModel),
+ ]
+
+ @property
+ def loop_intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ components.scheduler.set_begin_index(0)
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class FluxDenoiseStep(FluxDenoiseLoopWrapper):
+ block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser]
+ block_names = ["denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `FluxLoopDenoiser`\n"
+ " - `FluxLoopAfterDenoiser`\n"
+ "This block supports text2image tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py
new file mode 100644
index 0000000000..9bf2f54eec
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/encoders.py
@@ -0,0 +1,306 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import List, Optional, Union
+
+import regex as re
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import FluxModularPipeline
+
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class FluxTextEncoderStep(PipelineBlock):
+ model_name = "flux"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", CLIPTextModel),
+ ComponentSpec("tokenizer", CLIPTokenizer),
+ ComponentSpec("text_encoder_2", T5EncoderModel),
+ ComponentSpec("tokenizer_2", T5TokenizerFast),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("prompt_2"),
+ InputParam("joint_attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "text_ids",
+ type_hint=torch.Tensor,
+ description="ids from the text sequence for RoPE",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ for prompt in [block_state.prompt, block_state.prompt_2]:
+ if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")
+
+ @staticmethod
+ def _get_t5_prompt_embeds(
+ components,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int,
+ max_sequence_length: int,
+ device: torch.device,
+ ):
+ dtype = components.text_encoder_2.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)
+
+ text_inputs = components.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+
+ untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ @staticmethod
+ def _get_clip_prompt_embeds(
+ components,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int,
+ device: torch.device,
+ ):
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, components.tokenizer)
+
+ text_inputs = components.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=components.tokenizer.model_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ tokenizer_max_length = components.tokenizer.model_max_length
+ untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or components._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
+ components._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if components.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(components.text_encoder, lora_scale)
+ if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(components.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
+ components,
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
+ components,
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if components.text_encoder is not None:
+ if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+ dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ @torch.no_grad()
+ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ block_state.text_encoder_lora_scale = (
+ block_state.joint_attention_kwargs.get("scale", None)
+ if block_state.joint_attention_kwargs is not None
+ else None
+ )
+ (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
+ components,
+ prompt=block_state.prompt,
+ prompt_2=None,
+ prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ device=block_state.device,
+ num_images_per_prompt=1, # hardcoded for now.
+ lora_scale=block_state.text_encoder_lora_scale,
+ )
+
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
new file mode 100644
index 0000000000..b170673037
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -0,0 +1,125 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import FluxInputStep, FluxPrepareLatentsStep, FluxSetTimestepsStep
+from .decoders import FluxDecodeStep
+from .denoise import FluxDenoiseStep
+from .encoders import FluxTextEncoderStep
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# before_denoise: text2vid
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ FluxInputStep,
+ FluxPrepareLatentsStep,
+ FluxSetTimestepsStep,
+ ]
+ block_names = ["input", "prepare_latents", "set_timesteps"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `FluxSetTimestepsStep` is used to set the timesteps\n"
+ )
+
+
+# before_denoise: all task (text2vid,)
+class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxBeforeDenoiseStep]
+ block_names = ["text2image"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2image.\n"
+ + " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ )
+
+
+# denoise: text2image
+class FluxAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [FluxDenoiseStep]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2image tasks."
+ " - `FluxDenoiseStep` (denoise) for text2image tasks."
+ )
+
+
+# decode: all task (text2img, img2img, inpainting)
+class FluxAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [FluxDecodeStep]
+ block_names = ["non-inpaint"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self):
+ return "Decode step that decode the denoised latents into videos outputs.\n - `FluxDecodeStep`"
+
+
+# text2image
+class FluxAutoBlocks(SequentialPipelineBlocks):
+ block_classes = [FluxTextEncoderStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep, FluxAutoDecodeStep]
+ block_names = ["text_encoder", "before_denoise", "denoise", "decoder"]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image using Flux.\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
+
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep),
+ ("input", FluxInputStep),
+ ("prepare_latents", FluxPrepareLatentsStep),
+ # Setting it after preparation of latents because we rely on `latents`
+ # to calculate `img_seq_len` for `shift`.
+ ("set_timesteps", FluxSetTimestepsStep),
+ ("denoise", FluxDenoiseStep),
+ ("decode", FluxDecodeStep),
+ ]
+)
+
+
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", FluxTextEncoderStep),
+ ("before_denoise", FluxAutoBeforeDenoiseStep),
+ ("denoise", FluxAutoDenoiseStep),
+ ("decode", FluxAutoDecodeStep),
+ ]
+)
+
+
+ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "auto": AUTO_BLOCKS}
diff --git a/src/diffusers/modular_pipelines/flux/modular_pipeline.py b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
new file mode 100644
index 0000000000..3cd5df0c70
--- /dev/null
+++ b/src/diffusers/modular_pipelines/flux/modular_pipeline.py
@@ -0,0 +1,59 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import FluxLoraLoaderMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin):
+ """
+ A ModularPipeline for Flux.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ return 128
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if getattr(self, "vae", None) is not None:
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ return vae_scale_factor
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if getattr(self, "transformer", None):
+ num_channels_latents = self.transformer.config.in_channels // 4
+ return num_channels_latents
diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py
new file mode 100644
index 0000000000..0ef1d59f4d
--- /dev/null
+++ b/src/diffusers/modular_pipelines/modular_pipeline.py
@@ -0,0 +1,2841 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import inspect
+import os
+import traceback
+import warnings
+from collections import OrderedDict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from huggingface_hub import create_repo
+from huggingface_hub.utils import validate_hf_hub_args
+from tqdm.auto import tqdm
+from typing_extensions import Self
+
+from ..configuration_utils import ConfigMixin, FrozenDict
+from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
+from ..utils import (
+ PushToHubMixin,
+ is_accelerate_available,
+ logging,
+)
+from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
+from ..utils.hub_utils import load_or_create_model_card, populate_model_card
+from .components_manager import ComponentsManager
+from .modular_pipeline_utils import (
+ ComponentSpec,
+ ConfigSpec,
+ InputParam,
+ InsertableDict,
+ OutputParam,
+ format_components,
+ format_configs,
+ format_inputs_short,
+ format_intermediates_short,
+ make_doc_string,
+)
+
+
+if is_accelerate_available():
+ import accelerate
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+MODULAR_PIPELINE_MAPPING = OrderedDict(
+ [
+ ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"),
+ ("wan", "WanModularPipeline"),
+ ("flux", "FluxModularPipeline"),
+ ]
+)
+
+MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict(
+ [
+ ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"),
+ ("WanModularPipeline", "WanAutoBlocks"),
+ ("FluxModularPipeline", "FluxAutoBlocks"),
+ ]
+)
+
+
+@dataclass
+class PipelineState:
+ """
+ [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
+ """
+
+ inputs: Dict[str, Any] = field(default_factory=dict)
+ intermediates: Dict[str, Any] = field(default_factory=dict)
+ input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+ intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
+
+ def set_input(self, key: str, value: Any, kwargs_type: str = None):
+ """
+ Add an input to the immutable pipeline state, i.e, pipeline_state.inputs.
+
+ The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
+ set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
+ pipeline block has "guider_kwargs" in its expected_inputs list.
+
+ Args:
+ key (str): The key for the input
+ value (Any): The input value
+ kwargs_type (str): The kwargs_type with which the input is associated
+ """
+ self.inputs[key] = value
+ if kwargs_type is not None:
+ if kwargs_type not in self.input_kwargs:
+ self.input_kwargs[kwargs_type] = [key]
+ else:
+ self.input_kwargs[kwargs_type].append(key)
+
+ def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
+ """
+ Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates.
+
+ The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
+ if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
+ automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
+
+ Args:
+ key (str): The key for the intermediate value
+ value (Any): The intermediate value
+ kwargs_type (str): The kwargs_type with which the intermediate value is associated
+ """
+ self.intermediates[key] = value
+ if kwargs_type is not None:
+ if kwargs_type not in self.intermediate_kwargs:
+ self.intermediate_kwargs[kwargs_type] = [key]
+ else:
+ self.intermediate_kwargs[kwargs_type].append(key)
+
+ def get_input(self, key: str, default: Any = None) -> Any:
+ """
+ Get an input from the pipeline state.
+
+ Args:
+ key (str): The key for the input
+ default (Any): The default value to return if the input is not found
+
+ Returns:
+ Any: The input value
+ """
+ value = self.inputs.get(key, default)
+ if value is not None:
+ return deepcopy(value)
+
+ def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
+ """
+ Get multiple inputs from the pipeline state.
+
+ Args:
+ keys (List[str]): The keys for the inputs
+ default (Any): The default value to return if the input is not found
+
+ Returns:
+ Dict[str, Any]: Dictionary of inputs with matching keys
+ """
+ return {key: self.inputs.get(key, default) for key in keys}
+
+ def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+ """
+ Get all inputs with matching kwargs_type.
+
+ Args:
+ kwargs_type (str): The kwargs_type to filter by
+
+ Returns:
+ Dict[str, Any]: Dictionary of inputs with matching kwargs_type
+ """
+ input_names = self.input_kwargs.get(kwargs_type, [])
+ return self.get_inputs(input_names)
+
+ def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
+ """
+ Get all intermediates with matching kwargs_type.
+
+ Args:
+ kwargs_type (str): The kwargs_type to filter by
+
+ Returns:
+ Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
+ """
+ intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
+ return self.get_intermediates(intermediate_names)
+
+ def get_intermediate(self, key: str, default: Any = None) -> Any:
+ """
+ Get an intermediate value from the pipeline state.
+
+ Args:
+ key (str): The key for the intermediate value
+ default (Any): The default value to return if the intermediate value is not found
+
+ Returns:
+ Any: The intermediate value
+ """
+ return self.intermediates.get(key, default)
+
+ def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
+ """
+ Get multiple intermediate values from the pipeline state.
+
+ Args:
+ keys (List[str]): The keys for the intermediate values
+ default (Any): The default value to return if the intermediate value is not found
+
+ Returns:
+ Dict[str, Any]: Dictionary of intermediate values with matching keys
+ """
+ return {key: self.intermediates.get(key, default) for key in keys}
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Convert PipelineState to a dictionary.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing all attributes of the PipelineState
+ """
+ return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates}
+
+ def __repr__(self):
+ def format_value(v):
+ if hasattr(v, "shape") and hasattr(v, "dtype"):
+ return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+ elif isinstance(v, list) and len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ return f"[Tensor(dtype={v[0].dtype}, shape={v[0].shape}), ...]"
+ else:
+ return repr(v)
+
+ inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items())
+ intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items())
+
+ # Format input_kwargs and intermediate_kwargs
+ input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
+ intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
+
+ return (
+ f"PipelineState(\n"
+ f" inputs={{\n{inputs}\n }},\n"
+ f" intermediates={{\n{intermediates}\n }},\n"
+ f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
+ f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
+ f")"
+ )
+
+
+@dataclass
+class BlockState:
+ """
+ Container for block state data with attribute access and formatted representation.
+ """
+
+ def __init__(self, **kwargs):
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def __getitem__(self, key: str):
+ # allows block_state["foo"]
+ return getattr(self, key, None)
+
+ def __setitem__(self, key: str, value: Any):
+ # allows block_state["foo"] = "bar"
+ setattr(self, key, value)
+
+ def as_dict(self):
+ """
+ Convert BlockState to a dictionary.
+
+ Returns:
+ Dict[str, Any]: Dictionary containing all attributes of the BlockState
+ """
+ return dict(self.__dict__.items())
+
+ def __repr__(self):
+ def format_value(v):
+ # Handle tensors directly
+ if hasattr(v, "shape") and hasattr(v, "dtype"):
+ return f"Tensor(dtype={v.dtype}, shape={v.shape})"
+
+ # Handle lists of tensors
+ elif isinstance(v, list):
+ if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ shapes = [t.shape for t in v]
+ return f"List[{len(v)}] of Tensors with shapes {shapes}"
+ return repr(v)
+
+ # Handle tuples of tensors
+ elif isinstance(v, tuple):
+ if len(v) > 0 and hasattr(v[0], "shape") and hasattr(v[0], "dtype"):
+ shapes = [t.shape for t in v]
+ return f"Tuple[{len(v)}] of Tensors with shapes {shapes}"
+ return repr(v)
+
+ # Handle dicts with tensor values
+ elif isinstance(v, dict):
+ formatted_dict = {}
+ for k, val in v.items():
+ if hasattr(val, "shape") and hasattr(val, "dtype"):
+ formatted_dict[k] = f"Tensor(shape={val.shape}, dtype={val.dtype})"
+ elif (
+ isinstance(val, list)
+ and len(val) > 0
+ and hasattr(val[0], "shape")
+ and hasattr(val[0], "dtype")
+ ):
+ shapes = [t.shape for t in val]
+ formatted_dict[k] = f"List[{len(val)}] of Tensors with shapes {shapes}"
+ else:
+ formatted_dict[k] = repr(val)
+ return formatted_dict
+
+ # Default case
+ return repr(v)
+
+ attributes = "\n".join(f" {k}: {format_value(v)}" for k, v in self.__dict__.items())
+ return f"BlockState(\n{attributes}\n)"
+
+
+class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
+ """
+ Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
+ LoopSequentialPipelineBlocks
+
+ [`ModularPipelineBlocks`] provides method to load and save the defination of pipeline blocks.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ config_name = "config.json"
+ model_name = None
+
+ @classmethod
+ def _get_signature_keys(cls, obj):
+ parameters = inspect.signature(obj.__init__).parameters
+ required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+ optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+ expected_modules = set(required_parameters.keys()) - {"self"}
+
+ return expected_modules, optional_parameters
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: Optional[bool] = None,
+ **kwargs,
+ ):
+ hub_kwargs_names = [
+ "cache_dir",
+ "force_download",
+ "local_files_only",
+ "proxies",
+ "resume_download",
+ "revision",
+ "subfolder",
+ "token",
+ ]
+ hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+
+ config = cls.load_config(pretrained_model_name_or_path)
+ has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
+ trust_remote_code = resolve_trust_remote_code(
+ trust_remote_code, pretrained_model_name_or_path, has_remote_code
+ )
+ if not (has_remote_code and trust_remote_code):
+ raise ValueError(
+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
+ )
+
+ class_ref = config["auto_map"][cls.__name__]
+ module_file, class_name = class_ref.split(".")
+ module_file = module_file + ".py"
+ block_cls = get_class_from_dynamic_module(
+ pretrained_model_name_or_path,
+ module_file=module_file,
+ class_name=class_name,
+ **hub_kwargs,
+ **kwargs,
+ )
+ expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
+ block_kwargs = {
+ name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
+ }
+
+ return block_cls(**block_kwargs)
+
+ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
+ # TODO: factor out this logic.
+ cls_name = self.__class__.__name__
+
+ full_mod = type(self).__module__
+ module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
+ parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
+ auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
+
+ self.register_to_config(auto_map=auto_map)
+ self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
+ config = dict(self.config)
+ self._internal_dict = FrozenDict(config)
+
+ def init_pipeline(
+ self,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ ) -> "ModularPipeline":
+ """
+ create a ModularPipeline, optionally accept modular_repo to load from hub.
+ """
+ pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__)
+ diffusers_module = importlib.import_module("diffusers")
+ pipeline_class = getattr(diffusers_module, pipeline_class_name)
+
+ modular_pipeline = pipeline_class(
+ blocks=deepcopy(self),
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ components_manager=components_manager,
+ collection=collection,
+ )
+ return modular_pipeline
+
+ @staticmethod
+ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
+ """
+ Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
+ current default value is None and new default value is not None. Warns if multiple non-None default values
+ exist for the same input.
+
+ Args:
+ named_input_lists: List of tuples containing (block_name, input_param_list) pairs
+
+ Returns:
+ List[InputParam]: Combined list of unique InputParam objects
+ """
+ combined_dict = {} # name -> InputParam
+ value_sources = {} # name -> block_name
+
+ for block_name, inputs in named_input_lists:
+ for input_param in inputs:
+ if input_param.name is None and input_param.kwargs_type is not None:
+ input_name = "*_" + input_param.kwargs_type
+ else:
+ input_name = input_param.name
+ if input_name in combined_dict:
+ current_param = combined_dict[input_name]
+ if (
+ current_param.default is not None
+ and input_param.default is not None
+ and current_param.default != input_param.default
+ ):
+ warnings.warn(
+ f"Multiple different default values found for input '{input_name}': "
+ f"{current_param.default} (from block '{value_sources[input_name]}') and "
+ f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
+ )
+ if current_param.default is None and input_param.default is not None:
+ combined_dict[input_name] = input_param
+ value_sources[input_name] = block_name
+ else:
+ combined_dict[input_name] = input_param
+ value_sources[input_name] = block_name
+
+ return list(combined_dict.values())
+
+ @staticmethod
+ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> List[OutputParam]:
+ """
+ Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
+ occurrence of each output name.
+
+ Args:
+ named_output_lists: List of tuples containing (block_name, output_param_list) pairs
+
+ Returns:
+ List[OutputParam]: Combined list of unique OutputParam objects
+ """
+ combined_dict = {} # name -> OutputParam
+
+ for block_name, outputs in named_output_lists:
+ for output_param in outputs:
+ if (output_param.name not in combined_dict) or (
+ combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
+ ):
+ combined_dict[output_param.name] = output_param
+
+ return list(combined_dict.values())
+
+
+class PipelineBlock(ModularPipelineBlocks):
+ """
+ A Pipeline Block is the basic building block of a Modular Pipeline.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Args:
+ description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
+ expected_components (List[ComponentSpec], optional):
+ A list of components that are expected to be used in the block, defaults to []. To override, define as a
+ property in subclasses.
+ expected_configs (List[ConfigSpec], optional):
+ A list of configs that are expected to be used in the block, defaults to []. To override, define as a
+ property in subclasses.
+ inputs (List[InputParam], optional):
+ A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
+ property in subclasses.
+ intermediate_inputs (List[InputParam], optional):
+ A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
+ define as a property in subclasses.
+ intermediate_outputs (List[OutputParam], optional):
+ A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
+ define as a property in subclasses.
+ outputs (List[OutputParam], optional):
+ A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
+ property in subclasses.
+ required_inputs (List[str], optional):
+ A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
+ a property in subclasses.
+ required_intermediate_inputs (List[str], optional):
+ A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
+ override, define as a property in subclasses.
+ required_intermediate_outputs (List[str], optional):
+ A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
+ override, define as a property in subclasses.
+ """
+
+ model_name = None
+
+ def __init__(self):
+ self.sub_blocks = InsertableDict()
+
+ @property
+ def description(self) -> str:
+ """Description of the block. Must be implemented by subclasses."""
+ # raise NotImplementedError("description method must be implemented in subclasses")
+ return "TODO: add a description"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ """List of input parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ """List of intermediate input parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
+
+ def _get_outputs(self):
+ return self.intermediate_outputs
+
+ # YiYi TODO: is it too easy for user to unintentionally override these properties?
+ # Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
+ @property
+ def outputs(self) -> List[OutputParam]:
+ return self._get_outputs()
+
+ def _get_required_inputs(self):
+ input_names = []
+ for input_param in self.inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+ return input_names
+
+ @property
+ def required_inputs(self) -> List[str]:
+ return self._get_required_inputs()
+
+ def _get_required_intermediate_inputs(self):
+ input_names = []
+ for input_param in self.intermediate_inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+ return input_names
+
+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
+ # intermediate_inputs is by default required, unless you manually handle it inside the block
+ @property
+ def required_intermediate_inputs(self) -> List[str]:
+ return self._get_required_intermediate_inputs()
+
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ raise NotImplementedError("__call__ method must be implemented in subclasses")
+
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - use format_components with add_empty_lines=False
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+ components = " " + components_str.replace("\n", "\n ")
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+ configs = " " + configs_str.replace("\n", "\n ")
+
+ # Inputs section
+ inputs_str = format_inputs_short(self.inputs)
+ inputs = "Inputs:\n " + inputs_str
+
+ # Intermediates section
+ intermediates_str = format_intermediates_short(
+ self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
+ )
+ intermediates = f"Intermediates:\n{intermediates_str}"
+
+ return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.intermediate_inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+ # YiYi TODO: input and inteermediate inputs with same name? should warn?
+ def get_block_state(self, state: PipelineState) -> dict:
+ """Get all inputs and intermediates in one dictionary"""
+ data = {}
+
+ # Check inputs
+ for input_param in self.inputs:
+ if input_param.name:
+ value = state.get_input(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required input '{input_param.name}' is missing")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all inputs with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
+ if inputs_kwargs:
+ for k, v in inputs_kwargs.items():
+ if v is not None:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+
+ # Check intermediates
+ for input_param in self.intermediate_inputs:
+ if input_param.name:
+ value = state.get_intermediate(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all intermediates with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
+ if intermediate_kwargs:
+ for k, v in intermediate_kwargs.items():
+ if v is not None:
+ if k not in data:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+ return BlockState(**data)
+
+ def set_block_state(self, state: PipelineState, block_state: BlockState):
+ for output_param in self.intermediate_outputs:
+ if not hasattr(block_state, output_param.name):
+ raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+ param = getattr(block_state, output_param.name)
+ state.set_intermediate(output_param.name, param, output_param.kwargs_type)
+
+ for input_param in self.intermediate_inputs:
+ if hasattr(block_state, input_param.name):
+ param = getattr(block_state, input_param.name)
+ # Only add if the value is different from what's in the state
+ current_value = state.get_intermediate(input_param.name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set_intermediate(input_param.name, param, input_param.kwargs_type)
+
+ for input_param in self.intermediate_inputs:
+ if input_param.name and hasattr(block_state, input_param.name):
+ param = getattr(block_state, input_param.name)
+ # Only add if the value is different from what's in the state
+ current_value = state.get_intermediate(input_param.name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set_intermediate(input_param.name, param, input_param.kwargs_type)
+ elif input_param.kwargs_type:
+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
+ # we need to first find out which inputs are and loop through them.
+ intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
+ for param_name, current_value in intermediate_kwargs.items():
+ param = getattr(block_state, param_name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set_intermediate(param_name, param, input_param.kwargs_type)
+
+
+class AutoPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline Blocks that automatically selects a block to run based on the inputs.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ block_trigger_inputs: List of input names that trigger specific blocks, with None for default
+ """
+
+ block_classes = []
+ block_names = []
+ block_trigger_inputs = []
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block_cls in zip(self.block_names, self.block_classes):
+ sub_blocks[block_name] = block_cls()
+ self.sub_blocks = sub_blocks
+ if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
+ raise ValueError(
+ f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
+ )
+ default_blocks = [t for t in self.block_trigger_inputs if t is None]
+ # can only have 1 or 0 default block, and has to put in the last
+ # the order of blocks matters here because the first block with matching trigger will be dispatched
+ # e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
+ # as long as mask is provided, it is inpaint; if only image is provided, it is img2img
+ if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
+ raise ValueError(
+ f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
+ "in block_trigger_inputs."
+ )
+
+ # Map trigger inputs to block objects
+ self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
+ self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
+ self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
+
+ @property
+ def model_name(self):
+ return next(iter(self.sub_blocks.values())).model_name
+
+ @property
+ def description(self):
+ return ""
+
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ @property
+ def required_inputs(self) -> List[str]:
+ if None not in self.block_trigger_inputs:
+ return []
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_all = set(getattr(first_block, "required_inputs", set()))
+
+ # Intersect with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_all.intersection_update(block_required)
+
+ return list(required_by_all)
+
+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
+ # intermediate_inputs is by default required, unless you manually handle it inside the block
+ @property
+ def required_intermediate_inputs(self) -> List[str]:
+ if None not in self.block_trigger_inputs:
+ return []
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_all = set(getattr(first_block, "required_intermediate_inputs", set()))
+
+ # Intersect with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_intermediate_inputs", set()))
+ required_by_all.intersection_update(block_required)
+
+ return list(required_by_all)
+
+ # YiYi TODO: add test for this
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
+ combined_inputs = self.combine_inputs(*named_inputs)
+ # mark Required inputs only if that input is required by all the blocks
+ for input_param in combined_inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+ return combined_inputs
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ named_inputs = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()]
+ combined_inputs = self.combine_inputs(*named_inputs)
+ # mark Required inputs only if that input is required by all the blocks
+ for input_param in combined_inputs:
+ if input_param.name in self.required_intermediate_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+ return combined_inputs
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ @property
+ def outputs(self) -> List[str]:
+ named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ @torch.no_grad()
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ # Find default block first (if any)
+
+ block = self.trigger_to_block_map.get(None)
+ for input_name in self.block_trigger_inputs:
+ if input_name is not None and state.get_input(input_name) is not None:
+ block = self.trigger_to_block_map[input_name]
+ break
+ elif input_name is not None and state.get_intermediate(input_name) is not None:
+ block = self.trigger_to_block_map[input_name]
+ break
+
+ if block is None:
+ logger.warning(f"skipping auto block: {self.__class__.__name__}")
+ return pipeline, state
+
+ try:
+ logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
+ return block(pipeline, state)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: {block.__class__.__name__}\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+
+ def _get_trigger_inputs(self):
+ """
+ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
+ block_trigger_inputs values
+ """
+
+ def fn_recursive_get_trigger(blocks):
+ trigger_values = set()
+
+ if blocks is not None:
+ for name, block in blocks.items():
+ # Check if current block has trigger inputs(i.e. auto block)
+ if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
+ # Add all non-None values from the trigger inputs list
+ trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+ # If block has sub_blocks, recursively check them
+ if block.sub_blocks:
+ nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
+ trigger_values.update(nested_triggers)
+
+ return trigger_values
+
+ trigger_inputs = set(self.block_trigger_inputs)
+ trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
+
+ return trigger_inputs
+
+ @property
+ def trigger_inputs(self):
+ return self._get_trigger_inputs()
+
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ if self.trigger_inputs:
+ header += "\n"
+ header += " " + "=" * 100 + "\n"
+ header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
+ header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+ header += " " + "=" * 100 + "\n\n"
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # Get trigger input for this block
+ trigger = None
+ if hasattr(self, "block_to_trigger_map"):
+ trigger = self.block_to_trigger_map.get(name)
+ # Format the trigger info
+ if trigger is None:
+ trigger_str = "[default]"
+ elif isinstance(trigger, (list, tuple)):
+ trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
+ else:
+ trigger_str = f"[trigger: {trigger}]"
+ # For AutoPipelineBlocks, add bullet points
+ blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
+ else:
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.intermediate_inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+
+class SequentialPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
+ sequence.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ """
+
+ block_classes = []
+ block_names = []
+
+ @property
+ def description(self):
+ return ""
+
+ @property
+ def model_name(self):
+ return next(iter(self.sub_blocks.values())).model_name
+
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ @classmethod
+ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlocks":
+ """Creates a SequentialPipelineBlocks instance from a dictionary of blocks.
+
+ Args:
+ blocks_dict: Dictionary mapping block names to block classes or instances
+
+ Returns:
+ A new SequentialPipelineBlocks instance
+ """
+ instance = cls()
+
+ # Create instances if classes are provided
+ sub_blocks = InsertableDict()
+ for name, block in blocks_dict.items():
+ if inspect.isclass(block):
+ sub_blocks[name] = block()
+ else:
+ sub_blocks[name] = block
+
+ instance.block_classes = [block.__class__ for block in sub_blocks.values()]
+ instance.block_names = list(sub_blocks.keys())
+ instance.sub_blocks = sub_blocks
+ return instance
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block_cls in zip(self.block_names, self.block_classes):
+ sub_blocks[block_name] = block_cls()
+ self.sub_blocks = sub_blocks
+
+ @property
+ def required_inputs(self) -> List[str]:
+ # Get the first block from the dictionary
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_any = set(getattr(first_block, "required_inputs", set()))
+
+ # Union with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_any.update(block_required)
+
+ return list(required_by_any)
+
+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
+ # intermediate_inputs is by default required, unless you manually handle it inside the block
+ @property
+ def required_intermediate_inputs(self) -> List[str]:
+ required_intermediate_inputs = []
+ for input_param in self.intermediate_inputs:
+ if input_param.required:
+ required_intermediate_inputs.append(input_param.name)
+ return required_intermediate_inputs
+
+ # YiYi TODO: add test for this
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return self.get_inputs()
+
+ def get_inputs(self):
+ named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
+ combined_inputs = self.combine_inputs(*named_inputs)
+ # mark Required inputs only if that input is required any of the blocks
+ for input_param in combined_inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+ return combined_inputs
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return self.get_intermediate_inputs()
+
+ def get_intermediate_inputs(self):
+ inputs = []
+ outputs = set()
+ added_inputs = set()
+
+ # Go through all blocks in order
+ for block in self.sub_blocks.values():
+ # Add inputs that aren't in outputs yet
+ for inp in block.intermediate_inputs:
+ if inp.name not in outputs and inp.name not in added_inputs:
+ inputs.append(inp)
+ added_inputs.add(inp.name)
+
+ # Only add outputs if the block cannot be skipped
+ should_add_outputs = True
+ if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ should_add_outputs = False
+
+ if should_add_outputs:
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
+ return inputs
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = []
+ for name, block in self.sub_blocks.items():
+ inp_names = {inp.name for inp in block.intermediate_inputs}
+ # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
+ # filter out them here so they do not end up as intermediate_outputs
+ if name not in inp_names:
+ named_outputs.append((name, block.intermediate_outputs))
+ combined_outputs = self.combine_outputs(*named_outputs)
+ return combined_outputs
+
+ # YiYi TODO: I think we can remove the outputs property
+ @property
+ def outputs(self) -> List[str]:
+ # return next(reversed(self.sub_blocks.values())).intermediate_outputs
+ return self.intermediate_outputs
+
+ @torch.no_grad()
+ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+ for block_name, block in self.sub_blocks.items():
+ try:
+ pipeline, state = block(pipeline, state)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+ return pipeline, state
+
+ def _get_trigger_inputs(self):
+ """
+ Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
+ block_trigger_inputs values
+ """
+
+ def fn_recursive_get_trigger(blocks):
+ trigger_values = set()
+
+ if blocks is not None:
+ for name, block in blocks.items():
+ # Check if current block has trigger inputs(i.e. auto block)
+ if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
+ # Add all non-None values from the trigger inputs list
+ trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
+
+ # If block has sub_blocks, recursively check them
+ if block.sub_blocks:
+ nested_triggers = fn_recursive_get_trigger(block.sub_blocks)
+ trigger_values.update(nested_triggers)
+
+ return trigger_values
+
+ return fn_recursive_get_trigger(self.sub_blocks)
+
+ @property
+ def trigger_inputs(self):
+ return self._get_trigger_inputs()
+
+ def _traverse_trigger_blocks(self, trigger_inputs):
+ # Convert trigger_inputs to a set for easier manipulation
+ active_triggers = set(trigger_inputs)
+
+ def fn_recursive_traverse(block, block_name, active_triggers):
+ result_blocks = OrderedDict()
+
+ # sequential(include loopsequential) or PipelineBlock
+ if not hasattr(block, "block_trigger_inputs"):
+ if block.sub_blocks:
+ # sequential or LoopSequentialPipelineBlocks (keep traversing)
+ for sub_block_name, sub_block in block.sub_blocks.items():
+ blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+ blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
+ blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
+ result_blocks.update(blocks_to_update)
+ else:
+ # PipelineBlock
+ result_blocks[block_name] = block
+ # Add this block's output names to active triggers if defined
+ if hasattr(block, "outputs"):
+ active_triggers.update(out.name for out in block.outputs)
+ return result_blocks
+
+ # auto
+ else:
+ # Find first block_trigger_input that matches any value in our active_triggers
+ this_block = None
+ for trigger_input in block.block_trigger_inputs:
+ if trigger_input is not None and trigger_input in active_triggers:
+ this_block = block.trigger_to_block_map[trigger_input]
+ break
+
+ # If no matches found, try to get the default (None) block
+ if this_block is None and None in block.block_trigger_inputs:
+ this_block = block.trigger_to_block_map[None]
+
+ if this_block is not None:
+ # sequential/auto (keep traversing)
+ if this_block.sub_blocks:
+ result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
+ else:
+ # PipelineBlock
+ result_blocks[block_name] = this_block
+ # Add this block's output names to active triggers if defined
+ # YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
+ if hasattr(this_block, "outputs"):
+ active_triggers.update(out.name for out in this_block.outputs)
+
+ return result_blocks
+
+ all_blocks = OrderedDict()
+ for block_name, block in self.sub_blocks.items():
+ blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
+ all_blocks.update(blocks_to_update)
+ return all_blocks
+
+ def get_execution_blocks(self, *trigger_inputs):
+ trigger_inputs_all = self.trigger_inputs
+
+ if trigger_inputs is not None:
+ if not isinstance(trigger_inputs, (list, tuple, set)):
+ trigger_inputs = [trigger_inputs]
+ invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
+ if invalid_inputs:
+ logger.warning(
+ f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
+ )
+ trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
+
+ if trigger_inputs is None:
+ if None in trigger_inputs_all:
+ trigger_inputs = [None]
+ else:
+ trigger_inputs = [trigger_inputs_all[0]]
+ blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
+ return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
+
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ if self.trigger_inputs:
+ header += "\n"
+ header += " " + "=" * 100 + "\n"
+ header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
+ header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
+ # Get first trigger input as example
+ example_input = next(t for t in self.trigger_inputs if t is not None)
+ header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
+ header += " " + "=" * 100 + "\n\n"
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # Get trigger input for this block
+ trigger = None
+ if hasattr(self, "block_to_trigger_map"):
+ trigger = self.block_to_trigger_map.get(name)
+ # Format the trigger info
+ if trigger is None:
+ trigger_str = "[default]"
+ elif isinstance(trigger, (list, tuple)):
+ trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
+ else:
+ trigger_str = f"[trigger: {trigger}]"
+ # For AutoPipelineBlocks, add bullet points
+ blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
+ else:
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.intermediate_inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+
+class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
+ """
+ A Pipeline blocks that combines multiple pipeline block classes into a For Loop. When called, it will call each
+ block in sequence.
+
+ This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipeline blocks (such as loading or saving etc.)
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Attributes:
+ block_classes: List of block classes to be used
+ block_names: List of prefixes for each block
+ """
+
+ model_name = None
+ block_classes = []
+ block_names = []
+
+ @property
+ def description(self) -> str:
+ """Description of the block. Must be implemented by subclasses."""
+ raise NotImplementedError("description method must be implemented in subclasses")
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def loop_expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def loop_inputs(self) -> List[InputParam]:
+ """List of input parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def loop_intermediate_inputs(self) -> List[InputParam]:
+ """List of intermediate input parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def loop_intermediate_outputs(self) -> List[OutputParam]:
+ """List of intermediate output parameters. Must be implemented by subclasses."""
+ return []
+
+ @property
+ def loop_required_inputs(self) -> List[str]:
+ input_names = []
+ for input_param in self.loop_inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+ return input_names
+
+ @property
+ def loop_required_intermediate_inputs(self) -> List[str]:
+ input_names = []
+ for input_param in self.loop_intermediate_inputs:
+ if input_param.required:
+ input_names.append(input_param.name)
+ return input_names
+
+ # modified from SequentialPipelineBlocks to include loop_expected_components
+ @property
+ def expected_components(self):
+ expected_components = []
+ for block in self.sub_blocks.values():
+ for component in block.expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ for component in self.loop_expected_components:
+ if component not in expected_components:
+ expected_components.append(component)
+ return expected_components
+
+ # modified from SequentialPipelineBlocks to include loop_expected_configs
+ @property
+ def expected_configs(self):
+ expected_configs = []
+ for block in self.sub_blocks.values():
+ for config in block.expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ for config in self.loop_expected_configs:
+ if config not in expected_configs:
+ expected_configs.append(config)
+ return expected_configs
+
+ # modified from SequentialPipelineBlocks to include loop_inputs
+ def get_inputs(self):
+ named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
+ named_inputs.append(("loop", self.loop_inputs))
+ combined_inputs = self.combine_inputs(*named_inputs)
+ # mark Required inputs only if that input is required any of the blocks
+ for input_param in combined_inputs:
+ if input_param.name in self.required_inputs:
+ input_param.required = True
+ else:
+ input_param.required = False
+ return combined_inputs
+
+ @property
+ # Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
+ def inputs(self):
+ return self.get_inputs()
+
+ # modified from SequentialPipelineBlocks to include loop_intermediate_inputs
+ @property
+ def intermediate_inputs(self):
+ intermediates = self.get_intermediate_inputs()
+ intermediate_names = [input.name for input in intermediates]
+ for loop_intermediate_input in self.loop_intermediate_inputs:
+ if loop_intermediate_input.name not in intermediate_names:
+ intermediates.append(loop_intermediate_input)
+ return intermediates
+
+ # modified from SequentialPipelineBlocks
+ def get_intermediate_inputs(self):
+ inputs = []
+ outputs = set()
+
+ # Go through all blocks in order
+ for block in self.sub_blocks.values():
+ # Add inputs that aren't in outputs yet
+ inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs)
+
+ # Only add outputs if the block cannot be skipped
+ should_add_outputs = True
+ if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
+ should_add_outputs = False
+
+ if should_add_outputs:
+ # Add this block's outputs
+ block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
+ outputs.update(block_intermediate_outputs)
+ return inputs
+
+ # modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
+ @property
+ def required_inputs(self) -> List[str]:
+ # Get the first block from the dictionary
+ first_block = next(iter(self.sub_blocks.values()))
+ required_by_any = set(getattr(first_block, "required_inputs", set()))
+
+ required_by_loop = set(getattr(self, "loop_required_inputs", set()))
+ required_by_any.update(required_by_loop)
+
+ # Union with required inputs from all other blocks
+ for block in list(self.sub_blocks.values())[1:]:
+ block_required = set(getattr(block, "required_inputs", set()))
+ required_by_any.update(block_required)
+
+ return list(required_by_any)
+
+ # YiYi TODO: maybe we do not need this, it is only used in docstring,
+ # intermediate_inputs is by default required, unless you manually handle it inside the block
+ @property
+ def required_intermediate_inputs(self) -> List[str]:
+ required_intermediate_inputs = []
+ for input_param in self.intermediate_inputs:
+ if input_param.required:
+ required_intermediate_inputs.append(input_param.name)
+ for input_param in self.loop_intermediate_inputs:
+ if input_param.required:
+ required_intermediate_inputs.append(input_param.name)
+ return required_intermediate_inputs
+
+ # YiYi TODO: this need to be thought about more
+ # modified from SequentialPipelineBlocks to include loop_intermediate_outputs
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
+ combined_outputs = self.combine_outputs(*named_outputs)
+ for output in self.loop_intermediate_outputs:
+ if output.name not in {output.name for output in combined_outputs}:
+ combined_outputs.append(output)
+ return combined_outputs
+
+ # YiYi TODO: this need to be thought about more
+ @property
+ def outputs(self) -> List[str]:
+ return next(reversed(self.sub_blocks.values())).intermediate_outputs
+
+ def __init__(self):
+ sub_blocks = InsertableDict()
+ for block_name, block_cls in zip(self.block_names, self.block_classes):
+ sub_blocks[block_name] = block_cls()
+ self.sub_blocks = sub_blocks
+
+ @classmethod
+ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
+ """
+ Creates a LoopSequentialPipelineBlocks instance from a dictionary of blocks.
+
+ Args:
+ blocks_dict: Dictionary mapping block names to block instances
+
+ Returns:
+ A new LoopSequentialPipelineBlocks instance
+ """
+ instance = cls()
+
+ # Create instances if classes are provided
+ sub_blocks = InsertableDict()
+ for name, block in blocks_dict.items():
+ if inspect.isclass(block):
+ sub_blocks[name] = block()
+ else:
+ sub_blocks[name] = block
+
+ instance.block_classes = [block.__class__ for block in blocks_dict.values()]
+ instance.block_names = list(blocks_dict.keys())
+ instance.sub_blocks = blocks_dict
+ return instance
+
+ def loop_step(self, components, state: PipelineState, **kwargs):
+ for block_name, block in self.sub_blocks.items():
+ try:
+ components, state = block(components, state, **kwargs)
+ except Exception as e:
+ error_msg = (
+ f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
+ f"Error details: {str(e)}\n"
+ f"Traceback:\n{traceback.format_exc()}"
+ )
+ logger.error(error_msg)
+ raise
+ return components, state
+
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
+
+ def get_block_state(self, state: PipelineState) -> dict:
+ """Get all inputs and intermediates in one dictionary"""
+ data = {}
+
+ # Check inputs
+ for input_param in self.inputs:
+ if input_param.name:
+ value = state.get_input(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required input '{input_param.name}' is missing")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all inputs with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
+ if inputs_kwargs:
+ for k, v in inputs_kwargs.items():
+ if v is not None:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+
+ # Check intermediates
+ for input_param in self.intermediate_inputs:
+ if input_param.name:
+ value = state.get_intermediate(input_param.name)
+ if input_param.required and value is None:
+ raise ValueError(f"Required intermediate input '{input_param.name}' is missing.")
+ elif value is not None or (value is None and input_param.name not in data):
+ data[input_param.name] = value
+ elif input_param.kwargs_type:
+ # if kwargs_type is provided, get all intermediates with matching kwargs_type
+ if input_param.kwargs_type not in data:
+ data[input_param.kwargs_type] = {}
+ intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
+ if intermediate_kwargs:
+ for k, v in intermediate_kwargs.items():
+ if v is not None:
+ if k not in data:
+ data[k] = v
+ data[input_param.kwargs_type][k] = v
+ return BlockState(**data)
+
+ def set_block_state(self, state: PipelineState, block_state: BlockState):
+ for output_param in self.intermediate_outputs:
+ if not hasattr(block_state, output_param.name):
+ raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
+ param = getattr(block_state, output_param.name)
+ state.set_intermediate(output_param.name, param, output_param.kwargs_type)
+
+ for input_param in self.intermediate_inputs:
+ if input_param.name and hasattr(block_state, input_param.name):
+ param = getattr(block_state, input_param.name)
+ # Only add if the value is different from what's in the state
+ current_value = state.get_intermediate(input_param.name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set_intermediate(input_param.name, param, input_param.kwargs_type)
+ elif input_param.kwargs_type:
+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
+ # we need to first find out which inputs are and loop through them.
+ intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
+ for param_name, current_value in intermediate_kwargs.items():
+ if not hasattr(block_state, param_name):
+ continue
+ param = getattr(block_state, param_name)
+ if current_value is not param: # Using identity comparison to check if object was modified
+ state.set_intermediate(param_name, param, input_param.kwargs_type)
+
+ @property
+ def doc(self):
+ return make_doc_string(
+ self.inputs,
+ self.intermediate_inputs,
+ self.outputs,
+ self.description,
+ class_name=self.__class__.__name__,
+ expected_components=self.expected_components,
+ expected_configs=self.expected_configs,
+ )
+
+ # modified from SequentialPipelineBlocks,
+ # (does not need trigger_inputs related part so removed them,
+ # do not need to support auto block for loop blocks)
+ def __repr__(self):
+ class_name = self.__class__.__name__
+ base_class = self.__class__.__bases__[0].__name__
+ header = (
+ f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
+ )
+
+ # Format description with proper indentation
+ desc_lines = self.description.split("\n")
+ desc = []
+ # First line with "Description:" label
+ desc.append(f" Description: {desc_lines[0]}")
+ # Subsequent lines with proper indentation
+ if len(desc_lines) > 1:
+ desc.extend(f" {line}" for line in desc_lines[1:])
+ desc = "\n".join(desc) + "\n"
+
+ # Components section - focus only on expected components
+ expected_components = getattr(self, "expected_components", [])
+ components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
+
+ # Configs section - use format_configs with add_empty_lines=False
+ expected_configs = getattr(self, "expected_configs", [])
+ configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
+
+ # Blocks section - moved to the end with simplified format
+ blocks_str = " Sub-Blocks:\n"
+ for i, (name, block) in enumerate(self.sub_blocks.items()):
+ # For SequentialPipelineBlocks, show execution order
+ blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
+
+ # Add block description
+ desc_lines = block.description.split("\n")
+ indented_desc = desc_lines[0]
+ if len(desc_lines) > 1:
+ indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
+ blocks_str += f" Description: {indented_desc}\n\n"
+
+ # Build the representation with conditional sections
+ result = f"{header}\n{desc}"
+
+ # Only add components section if it has content
+ if components_str.strip():
+ result += f"\n\n{components_str}"
+
+ # Only add configs section if it has content
+ if configs_str.strip():
+ result += f"\n\n{configs_str}"
+
+ # Always add blocks section
+ result += f"\n\n{blocks_str})"
+
+ return result
+
+ @torch.compiler.disable
+ def progress_bar(self, iterable=None, total=None):
+ if not hasattr(self, "_progress_bar_config"):
+ self._progress_bar_config = {}
+ elif not isinstance(self._progress_bar_config, dict):
+ raise ValueError(
+ f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+ )
+
+ if iterable is not None:
+ return tqdm(iterable, **self._progress_bar_config)
+ elif total is not None:
+ return tqdm(total=total, **self._progress_bar_config)
+ else:
+ raise ValueError("Either `total` or `iterable` has to be defined.")
+
+ def set_progress_bar_config(self, **kwargs):
+ self._progress_bar_config = kwargs
+
+
+# YiYi TODO:
+# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)
+# 2. do we need ConfigSpec? the are basically just key/val kwargs
+# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components()
+class ModularPipeline(ConfigMixin, PushToHubMixin):
+ """
+ Base class for all Modular pipelines.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+
+ Args:
+ blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
+ """
+
+ config_name = "modular_model_index.json"
+ hf_device_map = None
+
+ # YiYi TODO: add warning for passing multiple ComponentSpec/ConfigSpec with the same name
+ def __init__(
+ self,
+ blocks: Optional[ModularPipelineBlocks] = None,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Initialize a ModularPipeline instance.
+
+ This method sets up the pipeline by:
+ - creating default pipeline blocks if not provided
+ - gather component and config specifications based on the pipeline blocks's requirement (e.g.
+ expected_components, expected_configs)
+ - update the loading specs of from_pretrained components based on the modular_model_index.json file from
+ huggingface hub if `pretrained_model_name_or_path` is provided
+ - create defaultfrom_config components and register everything
+
+ Args:
+ blocks: `ModularPipelineBlocks` instance. If None, will attempt to load
+ default blocks based on the pipeline class name.
+ pretrained_model_name_or_path: Path to a pretrained pipeline configuration. If provided,
+ will load component specs (only for from_pretrained components) and config values from the saved
+ modular_model_index.json file.
+ components_manager:
+ Optional ComponentsManager for managing multiple component cross different pipelines and apply
+ offloading strategies.
+ collection: Optional collection name for organizing components in the ComponentsManager.
+ **kwargs: Additional arguments passed to `load_config()` when loading pretrained configuration.
+
+ Examples:
+ ```python
+ # Initialize with custom blocks
+ pipeline = ModularPipeline(blocks=my_custom_blocks)
+
+ # Initialize from pretrained configuration
+ pipeline = ModularPipeline(blocks=my_blocks, pretrained_model_name_or_path="my-repo/modular-pipeline")
+
+ # Initialize with components manager
+ pipeline = ModularPipeline(
+ blocks=my_blocks, components_manager=ComponentsManager(), collection="my_collection"
+ )
+ ```
+
+ Notes:
+ - If blocks is None, the method will try to find default blocks based on the pipeline class name
+ - Components with default_creation_method="from_config" are created immediately, its specs are not included
+ in config dict and will not be saved in `modular_model_index.json`
+ - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with
+ `load_default_components()`/`load_components()`
+ - The pipeline's config dict is populated with component specs (only for from_pretrained components) and
+ config values, which will be saved as `modular_model_index.json` during `save_pretrained`
+ - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
+ `_blocks_class_name` in the config dict
+ """
+ if blocks is None:
+ blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__)
+ if blocks_class_name is not None:
+ diffusers_module = importlib.import_module("diffusers")
+ blocks_class = getattr(diffusers_module, blocks_class_name)
+ blocks = blocks_class()
+ else:
+ logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
+
+ self.blocks = blocks
+ self._components_manager = components_manager
+ self._collection = collection
+ self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
+ self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
+
+ # update component_specs and config_specs from modular_repo
+ if pretrained_model_name_or_path is not None:
+ config_dict = self.load_config(pretrained_model_name_or_path, **kwargs)
+
+ for name, value in config_dict.items():
+ # all the components in modular_model_index.json are from_pretrained components
+ if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
+ library, class_name, component_spec_dict = value
+ component_spec = self._dict_to_component_spec(name, component_spec_dict)
+ component_spec.default_creation_method = "from_pretrained"
+ self._component_specs[name] = component_spec
+
+ elif name in self._config_specs:
+ self._config_specs[name].default = value
+
+ register_components_dict = {}
+ for name, component_spec in self._component_specs.items():
+ if component_spec.default_creation_method == "from_config":
+ component = component_spec.create()
+ else:
+ component = None
+ register_components_dict[name] = component
+ self.register_components(**register_components_dict)
+
+ default_configs = {}
+ for name, config_spec in self._config_specs.items():
+ default_configs[name] = config_spec.default
+ self.register_to_config(**default_configs)
+
+ self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
+
+ @property
+ def default_call_parameters(self) -> Dict[str, Any]:
+ """
+ Returns:
+ - Dictionary mapping input names to their default values
+ """
+ params = {}
+ for input_param in self.blocks.inputs:
+ params[input_param.name] = input_param.default
+ return params
+
+ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
+ """
+ Execute the pipeline by running the pipeline blocks with the given inputs.
+
+ Args:
+ state (`PipelineState`, optional):
+ PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
+ created based on the user inputs and the pipeline blocks's requirement.
+ output (`str` or `List[str]`, optional):
+ Optional specification of what to return:
+ - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
+ - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
+ - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
+ "latents"]`)
+
+
+ Examples:
+ ```python
+ # Get complete pipeline state
+ state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
+ print(state.intermediates) # All intermediate outputs
+
+ # Get specific output
+ image = pipeline(prompt="A beautiful sunset", output="image")
+
+ # Get multiple specific outputs
+ results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
+ image, latents = results["image"], results["latents"]
+
+ # Continue from previous state
+ state = pipeline(prompt="A beautiful sunset")
+ new_state = pipeline(state=state, output="image") # Continue processing
+ ```
+
+ Returns:
+ - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
+ - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
+ - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
+ `output=["image", "latents"]`)
+ """
+ if state is None:
+ state = PipelineState()
+
+ # Make a copy of the input kwargs
+ passed_kwargs = kwargs.copy()
+
+ # Add inputs to state, using defaults if not provided in the kwargs or the state
+ # if same input already in the state, will override it if provided in the kwargs
+
+ intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
+ for expected_input_param in self.blocks.inputs:
+ name = expected_input_param.name
+ default = expected_input_param.default
+ kwargs_type = expected_input_param.kwargs_type
+ if name in passed_kwargs:
+ if name not in intermediate_inputs:
+ state.set_input(name, passed_kwargs.pop(name), kwargs_type)
+ else:
+ state.set_input(name, passed_kwargs[name], kwargs_type)
+ elif name not in state.inputs:
+ state.set_input(name, default, kwargs_type)
+
+ for expected_intermediate_param in self.blocks.intermediate_inputs:
+ name = expected_intermediate_param.name
+ kwargs_type = expected_intermediate_param.kwargs_type
+ if name in passed_kwargs:
+ state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type)
+
+ # Warn about unexpected inputs
+ if len(passed_kwargs) > 0:
+ warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
+ # Run the pipeline
+ with torch.no_grad():
+ try:
+ _, state = self.blocks(self, state)
+ except Exception:
+ error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
+ logger.error(error_msg)
+ raise
+
+ if output is None:
+ return state
+
+ elif isinstance(output, str):
+ return state.get_intermediate(output)
+
+ elif isinstance(output, (list, tuple)):
+ return state.get_intermediates(output)
+ else:
+ raise ValueError(f"Output '{output}' is not a valid output type")
+
+ def load_default_components(self, **kwargs):
+ """
+ Load from_pretrained components using the loading specs in the config dict.
+
+ Args:
+ **kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc.
+ """
+ names = [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_pretrained"
+ ]
+ self.load_components(names=names, **kwargs)
+
+ @classmethod
+ @validate_hf_hub_args
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ trust_remote_code: Optional[bool] = None,
+ components_manager: Optional[ComponentsManager] = None,
+ collection: Optional[str] = None,
+ **kwargs,
+ ):
+ """
+ Load a ModularPipeline from a huggingface hub repo.
+
+ Args:
+ pretrained_model_name_or_path (`str` or `os.PathLike`, optional):
+ Path to a pretrained pipeline configuration. If provided, will load component specs (only for
+ from_pretrained components) and config values from the modular_model_index.json file.
+ trust_remote_code (`bool`, optional):
+ Whether to trust remote code when loading the pipeline, need to be set to True if you want to create
+ pipeline blocks based on the custom code in `pretrained_model_name_or_path`
+ components_manager (`ComponentsManager`, optional):
+ ComponentsManager instance for managing multiple component cross different pipelines and apply
+ offloading strategies.
+ collection (`str`, optional):`
+ Collection name for organizing components in the ComponentsManager.
+ """
+ from ..pipelines.pipeline_loading_utils import _get_pipeline_class
+
+ try:
+ blocks = ModularPipelineBlocks.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ except EnvironmentError:
+ blocks = None
+
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ proxies = kwargs.pop("proxies", None)
+ token = kwargs.pop("token", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ revision = kwargs.pop("revision", None)
+
+ load_config_kwargs = {
+ "cache_dir": cache_dir,
+ "force_download": force_download,
+ "proxies": proxies,
+ "token": token,
+ "local_files_only": local_files_only,
+ "revision": revision,
+ }
+
+ try:
+ config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
+ pipeline_class = _get_pipeline_class(cls, config=config_dict)
+ except EnvironmentError:
+ pipeline_class = cls
+ pretrained_model_name_or_path = None
+
+ pipeline = pipeline_class(
+ blocks=blocks,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ components_manager=components_manager,
+ collection=collection,
+ **kwargs,
+ )
+ return pipeline
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ """
+ Save the pipeline to a directory. It does not save components, you need to save them separately.
+
+ Args:
+ save_directory (`str` or `os.PathLike`):
+ Path to the directory where the pipeline will be saved.
+ push_to_hub (`bool`, optional):
+ Whether to push the pipeline to the huggingface hub.
+ **kwargs: Additional arguments passed to `save_config()` method
+ """
+ if push_to_hub:
+ commit_message = kwargs.pop("commit_message", None)
+ private = kwargs.pop("private", None)
+ create_pr = kwargs.pop("create_pr", False)
+ token = kwargs.pop("token", None)
+ repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
+ repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
+
+ # Create a new empty model card and eventually tag it
+ model_card = load_or_create_model_card(repo_id, token=token, is_pipeline=True)
+ model_card = populate_model_card(model_card)
+ model_card.save(os.path.join(save_directory, "README.md"))
+
+ # YiYi TODO: maybe order the json file to make it more readable: configs first, then components
+ self.save_config(save_directory=save_directory)
+
+ if push_to_hub:
+ self._upload_folder(
+ save_directory,
+ repo_id,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ )
+
+ @property
+ def doc(self):
+ """
+ Returns:
+ - The docstring of the pipeline blocks
+ """
+ return self.blocks.doc
+
+ def register_components(self, **kwargs):
+ """
+ Register components with their corresponding specifications.
+
+ This method is responsible for:
+ 1. Sets component objects as attributes on the loader (e.g., self.unet = unet)
+ 2. Updates the config dict, which will be saved as `modular_model_index.json` during `save_pretrained` (only
+ for from_pretrained components)
+ 3. Adds components to the component manager if one is attached (only for from_pretrained components)
+
+ This method is called when:
+ - Components are first initialized in __init__:
+ - from_pretrained components not loaded during __init__ so they are registered as None;
+ - non from_pretrained components are created during __init__ and registered as the object itself
+ - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or
+ loader.update_components(guider=guider_spec)
+ - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g.
+ loader.load_default_components(names=["unet"])
+
+ Args:
+ **kwargs: Keyword arguments where keys are component names and values are component objects.
+ E.g., register_components(unet=unet_model, text_encoder=encoder_model)
+
+ Notes:
+ - When registering None for a component, it sets attribute to None but still syncs specs with the config
+ dict, which will be saved as `modular_model_index.json` during `save_pretrained`
+ - component_specs are updated to match the new component outside of this method, e.g. in
+ `update_components()` method
+ """
+ for name, module in kwargs.items():
+ # current component spec
+ component_spec = self._component_specs.get(name)
+ if component_spec is None:
+ logger.warning(f"ModularPipeline.register_components: skipping unknown component '{name}'")
+ continue
+
+ # check if it is the first time registration, i.e. calling from __init__
+ is_registered = hasattr(self, name)
+ is_from_pretrained = component_spec.default_creation_method == "from_pretrained"
+
+ if module is not None:
+ # actual library and class name of the module
+ library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel")
+ else:
+ # if module is None, e.g. self.register_components(unet=None) during __init__
+ # we do not update the spec,
+ # but we still need to update the modular_model_index.json config based on component spec
+ library, class_name = None, None
+
+ # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config
+ # e.g. {"repo": "stabilityai/stable-diffusion-2-1",
+ # "type_hint": ("diffusers", "UNet2DConditionModel"),
+ # "subfolder": "unet",
+ # "variant": None,
+ # "revision": None}
+ component_spec_dict = self._component_spec_to_dict(component_spec)
+
+ register_dict = {name: (library, class_name, component_spec_dict)}
+
+ # set the component as attribute
+ # if it is not set yet, just set it and skip the process to check and warn below
+ if not is_registered:
+ if is_from_pretrained:
+ self.register_to_config(**register_dict)
+ setattr(self, name, module)
+ if module is not None and is_from_pretrained and self._components_manager is not None:
+ self._components_manager.add(name, module, self._collection)
+ continue
+
+ current_module = getattr(self, name, None)
+ # skip if the component is already registered with the same object
+ if current_module is module:
+ logger.info(
+ f"ModularPipeline.register_components: {name} is already registered with same object, skipping"
+ )
+ continue
+
+ # warn if unregister
+ if current_module is not None and module is None:
+ logger.info(
+ f"ModularPipeline.register_components: setting '{name}' to None "
+ f"(was {current_module.__class__.__name__})"
+ )
+ # same type, new instance → replace but send debug log
+ elif (
+ current_module is not None
+ and module is not None
+ and isinstance(module, current_module.__class__)
+ and current_module != module
+ ):
+ logger.debug(
+ f"ModularPipeline.register_components: replacing existing '{name}' "
+ f"(same type {type(current_module).__name__}, new instance)"
+ )
+
+ # update modular_model_index.json config
+ if is_from_pretrained:
+ self.register_to_config(**register_dict)
+ # finally set models
+ setattr(self, name, module)
+ # add to component manager if one is attached
+ if module is not None and is_from_pretrained and self._components_manager is not None:
+ self._components_manager.add(name, module, self._collection)
+
+ @property
+ def device(self) -> torch.device:
+ r"""
+ Returns:
+ `torch.device`: The torch device on which the pipeline is located.
+ """
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ return module.device
+
+ return torch.device("cpu")
+
+ @property
+ # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
+ Accelerate's module hooks.
+ """
+ for name, model in self.components.items():
+ if not isinstance(model, torch.nn.Module):
+ continue
+
+ if not hasattr(model, "_hf_hook"):
+ return self.device
+ for module in model.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ r"""
+ Returns:
+ `torch.dtype`: The torch dtype on which the pipeline is located.
+ """
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ for module in modules:
+ return module.dtype
+
+ return torch.float32
+
+ @property
+ def null_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for components that needs to be loaded
+ """
+ return [name for name in self._component_specs.keys() if hasattr(self, name) and getattr(self, name) is None]
+
+ @property
+ def component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for all components
+ """
+ return list(self.components.keys())
+
+ @property
+ def pretrained_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for from_pretrained components
+ """
+ return [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_pretrained"
+ ]
+
+ @property
+ def config_component_names(self) -> List[str]:
+ """
+ Returns:
+ - List of names for from_config components
+ """
+ return [
+ name
+ for name in self._component_specs.keys()
+ if self._component_specs[name].default_creation_method == "from_config"
+ ]
+
+ @property
+ def components(self) -> Dict[str, Any]:
+ """
+ Returns:
+ - Dictionary mapping component names to their objects (include both from_pretrained and from_config
+ components)
+ """
+ # return only components we've actually set as attributes on self
+ return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)}
+
+ def get_component_spec(self, name: str) -> ComponentSpec:
+ """
+ Returns:
+ - a copy of the ComponentSpec object for the given component name
+ """
+ return deepcopy(self._component_specs[name])
+
+ def update_components(self, **kwargs):
+ """
+ Update components and configuration values and specs after the pipeline has been instantiated.
+
+ This method allows you to:
+ 1. Replace existing components with new ones (e.g., updating `self.unet` or `self.text_encoder`)
+ 2. Update configuration values (e.g., changing `self.requires_safety_checker` flag)
+
+ In addition to updating the components and configuration values as pipeline attributes, the method also
+ updates:
+ - the corresponding specs in `_component_specs` and `_config_specs`
+ - the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
+
+ Args:
+ **kwargs: Component objects, ComponentSpec objects, or configuration values to update:
+ - Component objects: Only supports components we can extract specs using
+ `ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
+ ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
+ - ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
+ method to create a new component (e.g., `guider=ComponentSpec(name="guider",
+ type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
+ - Configuration values: Simple values to update configuration settings (e.g.,
+ `requires_safety_checker=False`)
+
+ Raises:
+ ValueError: If a component object is not supported in ComponentSpec.from_component() method:
+ - nn.Module components without a valid `_diffusers_load_id` attribute
+ - Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
+
+ Examples:
+ ```python
+ # Update multiple components at once
+ pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
+
+ # Update configuration values
+ pipeline.update_components(requires_safety_checker=False)
+
+ # Update both components and configs together
+ pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
+
+ # Update with ComponentSpec objects (from_config only)
+ pipeline.update_components(
+ guider=ComponentSpec(
+ name="guider",
+ type_hint=ClassifierFreeGuidance,
+ config={"guidance_scale": 5.0},
+ default_creation_method="from_config",
+ )
+ )
+ ```
+
+ Notes:
+ - Components with trained weights must be created using ComponentSpec.load(). If the component has not been
+ shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
+ - ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
+ - ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
+ update_components()
+ """
+
+ # extract component_specs_updates & config_specs_updates from `specs`
+ passed_component_specs = {
+ k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
+ }
+ passed_components = {
+ k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
+ }
+ passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
+
+ for name, component in passed_components.items():
+ current_component_spec = self._component_specs[name]
+
+ # warn if type changed
+ if current_component_spec.type_hint is not None and not isinstance(
+ component, current_component_spec.type_hint
+ ):
+ logger.warning(
+ f"ModularPipeline.update_components: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
+ )
+ # update _component_specs based on the new component
+ new_component_spec = ComponentSpec.from_component(name, component)
+ if new_component_spec.default_creation_method != current_component_spec.default_creation_method:
+ logger.warning(
+ f"ModularPipeline.update_components: changing the default_creation_method of {name} from {current_component_spec.default_creation_method} to {new_component_spec.default_creation_method}."
+ )
+
+ self._component_specs[name] = new_component_spec
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
+
+ created_components = {}
+ for name, component_spec in passed_component_specs.items():
+ if component_spec.default_creation_method == "from_pretrained":
+ raise ValueError(
+ "ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
+ )
+ created_components[name] = component_spec.create()
+ current_component_spec = self._component_specs[name]
+ # warn if type changed
+ if current_component_spec.type_hint is not None and not isinstance(
+ created_components[name], current_component_spec.type_hint
+ ):
+ logger.warning(
+ f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
+ )
+ # update _component_specs based on the user passed component_spec
+ self._component_specs[name] = component_spec
+ self.register_components(**passed_components, **created_components)
+
+ config_to_register = {}
+ for name, new_value in passed_config_values.items():
+ # e.g. requires_aesthetics_score = False
+ self._config_specs[name].default = new_value
+ config_to_register[name] = new_value
+ self.register_to_config(**config_to_register)
+
+ # YiYi TODO: support map for additional from_pretrained kwargs
+ # YiYi/Dhruv TODO: consolidate load_components and load_default_components?
+ def load_components(self, names: Union[List[str], str], **kwargs):
+ """
+ Load selected components from specs.
+
+ Args:
+ names: List of component names to load; by default will not load any components
+ **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be:
+ - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16
+ - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32}
+ - if potentially override ComponentSpec if passed a different loading field in kwargs, e.g. `repo`,
+ `variant`, `revision`, etc.
+ """
+
+ if isinstance(names, str):
+ names = [names]
+ elif not isinstance(names, list):
+ raise ValueError(f"Invalid type for names: {type(names)}")
+
+ components_to_load = {name for name in names if name in self._component_specs}
+ unknown_names = {name for name in names if name not in self._component_specs}
+ if len(unknown_names) > 0:
+ logger.warning(f"Unknown components will be ignored: {unknown_names}")
+
+ components_to_register = {}
+ for name in components_to_load:
+ spec = self._component_specs[name]
+ component_load_kwargs = {}
+ for key, value in kwargs.items():
+ if not isinstance(value, dict):
+ # if the value is a single value, apply it to all components
+ component_load_kwargs[key] = value
+ else:
+ if name in value:
+ # if it is a dict, check if the component name is in the dict
+ component_load_kwargs[key] = value[name]
+ elif "default" in value:
+ # check if the default is specified
+ component_load_kwargs[key] = value["default"]
+ try:
+ components_to_register[name] = spec.load(**component_load_kwargs)
+ except Exception as e:
+ logger.warning(f"Failed to create component '{name}': {e}")
+
+ # Register all components at once
+ self.register_components(**components_to_register)
+
+ # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active
+ def _maybe_raise_error_if_group_offload_active(
+ self, raise_error: bool = False, module: Optional[torch.nn.Module] = None
+ ) -> bool:
+ from ..hooks.group_offloading import _is_group_offload_enabled
+
+ components = self.components.values() if module is None else [module]
+ components = [component for component in components if isinstance(component, torch.nn.Module)]
+ for component in components:
+ if _is_group_offload_enabled(component):
+ if raise_error:
+ raise ValueError(
+ "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
+ "with group offloading enabled. This is not supported. Please disable group offloading for "
+ "components of the pipeline to use other offloading methods."
+ )
+ return True
+ return False
+
+ # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
+ def to(self, *args, **kwargs) -> Self:
+ r"""
+ Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
+ arguments of `self.to(*args, **kwargs).`
+
+
+
+ If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
+ the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
+
+
+
+
+ Here are the ways to call `to`:
+
+ - `to(dtype, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+ - `to(device, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the specified
+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
+ - `to(device=None, dtype=None, silence_dtype_warnings=False) → DiffusionPipeline` to return a pipeline with the
+ specified [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device) and
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+
+ Arguments:
+ dtype (`torch.dtype`, *optional*):
+ Returns a pipeline with the specified
+ [`dtype`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)
+ device (`torch.Device`, *optional*):
+ Returns a pipeline with the specified
+ [`device`](https://pytorch.org/docs/stable/tensor_attributes.html#torch.device)
+ silence_dtype_warnings (`str`, *optional*, defaults to `False`):
+ Whether to omit warnings if the target `dtype` is not compatible with the target `device`.
+
+ Returns:
+ [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
+ """
+ from ..pipelines.pipeline_utils import _check_bnb_status
+ from ..utils import is_accelerate_available, is_accelerate_version, is_hpu_available, is_transformers_version
+
+ dtype = kwargs.pop("dtype", None)
+ device = kwargs.pop("device", None)
+ silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)
+
+ dtype_arg = None
+ device_arg = None
+ if len(args) == 1:
+ if isinstance(args[0], torch.dtype):
+ dtype_arg = args[0]
+ else:
+ device_arg = torch.device(args[0]) if args[0] is not None else None
+ elif len(args) == 2:
+ if isinstance(args[0], torch.dtype):
+ raise ValueError(
+ "When passing two arguments, make sure the first corresponds to `device` and the second to `dtype`."
+ )
+ device_arg = torch.device(args[0]) if args[0] is not None else None
+ dtype_arg = args[1]
+ elif len(args) > 2:
+ raise ValueError("Please make sure to pass at most two arguments (`device` and `dtype`) `.to(...)`")
+
+ if dtype is not None and dtype_arg is not None:
+ raise ValueError(
+ "You have passed `dtype` both as an argument and as a keyword argument. Please only pass one of the two."
+ )
+
+ dtype = dtype or dtype_arg
+
+ if device is not None and device_arg is not None:
+ raise ValueError(
+ "You have passed `device` both as an argument and as a keyword argument. Please only pass one of the two."
+ )
+
+ device = device or device_arg
+ device_type = torch.device(device).type if device is not None else None
+ pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
+
+ # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
+ def module_is_sequentially_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
+ return False
+
+ _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+
+ if is_loaded_in_8bit_bnb:
+ return False
+
+ return hasattr(module, "_hf_hook") and (
+ isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
+ or hasattr(module._hf_hook, "hooks")
+ and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
+ )
+
+ def module_is_offloaded(module):
+ if not is_accelerate_available() or is_accelerate_version("<", "0.17.0.dev0"):
+ return False
+
+ return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
+
+ # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
+ pipeline_is_sequentially_offloaded = any(
+ module_is_sequentially_offloaded(module) for _, module in self.components.items()
+ )
+
+ is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
+ if is_pipeline_device_mapped:
+ raise ValueError(
+ "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
+ )
+
+ if device_type in ["cuda", "xpu"]:
+ if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
+ raise ValueError(
+ "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+ )
+ # PR: https://github.com/huggingface/accelerate/pull/3223/
+ elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
+ raise ValueError(
+ "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
+ )
+
+ # Display a warning in this case (the operation succeeds but the benefits are lost)
+ pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
+ if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
+ logger.warning(
+ f"It seems like you have activated model offloading by calling `enable_model_cpu_offload`, but are now manually moving the pipeline to GPU. It is strongly recommended against doing so as memory gains from offloading are likely to be lost. Offloading automatically takes care of moving the individual components {', '.join(self.components.keys())} to GPU when needed. To make sure offloading works as expected, you should consider moving the pipeline back to CPU: `pipeline.to('cpu')` or removing the move altogether if you use offloading."
+ )
+
+ # Enable generic support for Intel Gaudi accelerator using GPU/HPU migration
+ if device_type == "hpu" and kwargs.pop("hpu_migration", True) and is_hpu_available():
+ os.environ["PT_HPU_GPU_MIGRATION"] = "1"
+ logger.debug("Environment variable set: PT_HPU_GPU_MIGRATION=1")
+
+ import habana_frameworks.torch # noqa: F401
+
+ # HPU hardware check
+ if not (hasattr(torch, "hpu") and torch.hpu.is_available()):
+ raise ValueError("You are trying to call `.to('hpu')` but HPU device is unavailable.")
+
+ os.environ["PT_HPU_MAX_COMPOUND_OP_SIZE"] = "1"
+ logger.debug("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1")
+
+ modules = self.components.values()
+ modules = [m for m in modules if isinstance(m, torch.nn.Module)]
+
+ is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
+ for module in modules:
+ _, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
+ is_group_offloaded = self._maybe_raise_error_if_group_offload_active(module=module)
+
+ if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
+ )
+
+ if is_loaded_in_8bit_bnb and device is not None:
+ logger.warning(
+ f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
+ )
+
+ # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
+ # components can be from outside diffusers too, but still have group offloading enabled.
+ if (
+ self._maybe_raise_error_if_group_offload_active(raise_error=False, module=module)
+ and device is not None
+ ):
+ logger.warning(
+ f"The module '{module.__class__.__name__}' is group offloaded and moving it to {device} via `.to()` is not supported."
+ )
+
+ # This can happen for `transformer` models. CPU placement was added in
+ # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
+ if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
+ module.to(device=device)
+ elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
+ module.to(device, dtype)
+
+ if (
+ module.dtype == torch.float16
+ and str(device) in ["cpu"]
+ and not silence_dtype_warnings
+ and not is_offloaded
+ ):
+ logger.warning(
+ "Pipelines loaded with `dtype=torch.float16` cannot run with `cpu` device. It"
+ " is not recommended to move them to `cpu` as running them will fail. Please make"
+ " sure to use an accelerator to run the pipeline in inference, due to the lack of"
+ " support for`float16` operations on this device in PyTorch. Please, remove the"
+ " `torch_dtype=torch.float16` argument, or use another device for inference."
+ )
+ return self
+
+ @staticmethod
+ def _component_spec_to_dict(component_spec: ComponentSpec) -> Any:
+ """
+ Convert a ComponentSpec into a JSON‐serializable dict for saving as an entry in `modular_model_index.json`. If
+ the `default_creation_method` is not `from_pretrained`, return None.
+
+ This dict contains:
+ - "type_hint": Tuple[str, str]
+ Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
+ - All loading fields defined by `component_spec.loading_fields()`, typically:
+ - "repo": Optional[str]
+ The model repository (e.g., "stabilityai/stable-diffusion-xl").
+ - "subfolder": Optional[str]
+ A subfolder within the repo where this component lives.
+ - "variant": Optional[str]
+ An optional variant identifier for the model.
+ - "revision": Optional[str]
+ A specific git revision (commit hash, tag, or branch).
+ - ... any other loading fields defined on the spec.
+
+ Args:
+ component_spec (ComponentSpec):
+ The spec object describing one pipeline component.
+
+ Returns:
+ Dict[str, Any]: A mapping suitable for JSON serialization.
+
+ Example:
+ >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers import
+ UNet2DConditionModel >>> spec = ComponentSpec(
+ ... name="unet", ... type_hint=UNet2DConditionModel, ... config=None, ... repo="path/to/repo", ...
+ subfolder="subfolder", ... variant=None, ... revision=None, ...
+ default_creation_method="from_pretrained",
+ ... ) >>> ModularPipeline._component_spec_to_dict(spec) {
+ "type_hint": ("diffusers", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": "subfolder",
+ "variant": None, "revision": None,
+ }
+ """
+ if component_spec.default_creation_method != "from_pretrained":
+ return None
+
+ if component_spec.type_hint is not None:
+ lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint)
+ else:
+ lib_name = None
+ cls_name = None
+ load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()}
+ return {
+ "type_hint": (lib_name, cls_name),
+ **load_spec_dict,
+ }
+
+ @staticmethod
+ def _dict_to_component_spec(
+ name: str,
+ spec_dict: Dict[str, Any],
+ ) -> ComponentSpec:
+ """
+ Reconstruct a ComponentSpec from a loading specdict.
+
+ This method converts a dictionary representation back into a ComponentSpec object. The dict should contain:
+ - "type_hint": Tuple[str, str]
+ Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel"))
+ - All loading fields defined by `component_spec.loading_fields()`, typically:
+ - "repo": Optional[str]
+ The model repository (e.g., "stabilityai/stable-diffusion-xl").
+ - "subfolder": Optional[str]
+ A subfolder within the repo where this component lives.
+ - "variant": Optional[str]
+ An optional variant identifier for the model.
+ - "revision": Optional[str]
+ A specific git revision (commit hash, tag, or branch).
+ - ... any other loading fields defined on the spec.
+
+ Args:
+ name (str):
+ The name of the component.
+ specdict (Dict[str, Any]):
+ A dictionary containing the component specification data.
+
+ Returns:
+ ComponentSpec: A reconstructed ComponentSpec object.
+
+ Example:
+ >>> spec_dict = { ... "type_hint": ("diffusers", "UNet2DConditionModel"), ... "repo":
+ "stabilityai/stable-diffusion-xl", ... "subfolder": "unet", ... "variant": None, ... "revision": None, ...
+ } >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) ComponentSpec(
+ name="unet", type_hint=UNet2DConditionModel, config=None, repo="stabilityai/stable-diffusion-xl",
+ subfolder="unet", variant=None, revision=None, default_creation_method="from_pretrained"
+ )
+ """
+ # make a shallow copy so we can pop() safely
+ spec_dict = spec_dict.copy()
+ # pull out and resolve the stored type_hint
+ lib_name, cls_name = spec_dict.pop("type_hint")
+ if lib_name is not None and cls_name is not None:
+ type_hint = simple_get_class_obj(lib_name, cls_name)
+ else:
+ type_hint = None
+
+ # re‐assemble the ComponentSpec
+ return ComponentSpec(
+ name=name,
+ type_hint=type_hint,
+ **spec_dict,
+ )
diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
new file mode 100644
index 0000000000..f2fc015e94
--- /dev/null
+++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py
@@ -0,0 +1,673 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from collections import OrderedDict
+from dataclasses import dataclass, field, fields
+from typing import Any, Dict, List, Literal, Optional, Type, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, FrozenDict
+from ..utils import is_torch_available, logging
+
+
+if is_torch_available():
+ pass
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class InsertableDict(OrderedDict):
+ def insert(self, key, value, index):
+ items = list(self.items())
+
+ # Remove key if it already exists to avoid duplicates
+ items = [(k, v) for k, v in items if k != key]
+
+ # Insert at the specified index
+ items.insert(index, (key, value))
+
+ # Clear and update self
+ self.clear()
+ self.update(items)
+
+ # Return self for method chaining
+ return self
+
+ def __repr__(self):
+ if not self:
+ return "InsertableDict()"
+
+ items = []
+ for i, (key, value) in enumerate(self.items()):
+ if isinstance(value, type):
+ # For classes, show class name and
+ obj_repr = f""
+ else:
+ # For objects (instances) and other types, show class name and module
+ obj_repr = f""
+ items.append(f"{i}: ({repr(key)}, {obj_repr})")
+
+ return "InsertableDict([\n " + ",\n ".join(items) + "\n])"
+
+
+# YiYi TODO:
+# 1. validate the dataclass fields
+# 2. improve the docstring and potentially add a validator for load methods, make sure they are valid inputs to pass to from_pretrained()
+@dataclass
+class ComponentSpec:
+ """Specification for a pipeline component.
+
+ A component can be created in two ways:
+ 1. From scratch using __init__ with a config dict
+ 2. using `from_pretrained`
+
+ Attributes:
+ name: Name of the component
+ type_hint: Type of the component (e.g. UNet2DConditionModel)
+ description: Optional description of the component
+ config: Optional config dict for __init__ creation
+ repo: Optional repo path for from_pretrained creation
+ subfolder: Optional subfolder in repo
+ variant: Optional variant in repo
+ revision: Optional revision in repo
+ default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
+ """
+
+ name: Optional[str] = None
+ type_hint: Optional[Type] = None
+ description: Optional[str] = None
+ config: Optional[FrozenDict] = None
+ # YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
+ repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
+ subfolder: Optional[str] = field(default="", metadata={"loading": True})
+ variant: Optional[str] = field(default=None, metadata={"loading": True})
+ revision: Optional[str] = field(default=None, metadata={"loading": True})
+ default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
+
+ def __hash__(self):
+ """Make ComponentSpec hashable, using load_id as the hash value."""
+ return hash((self.name, self.load_id, self.default_creation_method))
+
+ def __eq__(self, other):
+ """Compare ComponentSpec objects based on name and load_id."""
+ if not isinstance(other, ComponentSpec):
+ return False
+ return (
+ self.name == other.name
+ and self.load_id == other.load_id
+ and self.default_creation_method == other.default_creation_method
+ )
+
+ @classmethod
+ def from_component(cls, name: str, component: Any) -> Any:
+ """Create a ComponentSpec from a Component.
+
+ Currently supports:
+ - Components created with `ComponentSpec.load()` method
+ - Components that are ConfigMixin subclasses but not nn.Modules (e.g. schedulers, guiders)
+
+ Args:
+ name: Name of the component
+ component: Component object to create spec from
+
+ Returns:
+ ComponentSpec object
+
+ Raises:
+ ValueError: If component is not supported (e.g. nn.Module without load_id, non-ConfigMixin)
+ """
+
+ # Check if component was created with ComponentSpec.load()
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ # component has a usable load_id -> from_pretrained, no warning needed
+ default_creation_method = "from_pretrained"
+ else:
+ # Component doesn't have a usable load_id, check if it's a nn.Module
+ if isinstance(component, torch.nn.Module):
+ raise ValueError(
+ "Cannot create ComponentSpec from a nn.Module that was not created with `ComponentSpec.load()` method."
+ )
+ # ConfigMixin objects without weights (e.g. scheduler & guider) can be recreated with from_config
+ elif isinstance(component, ConfigMixin):
+ # warn if component was not created with `ComponentSpec`
+ if not hasattr(component, "_diffusers_load_id"):
+ logger.warning(
+ "Component was not created using `ComponentSpec`, defaulting to `from_config` creation method"
+ )
+ default_creation_method = "from_config"
+ else:
+ # Not a ConfigMixin and not created with `ComponentSpec.load()` method -> throw error
+ raise ValueError(
+ f"Cannot create ComponentSpec from {name}({component.__class__.__name__}). Currently ComponentSpec.from_component() only supports: "
+ f" - components created with `ComponentSpec.load()` method"
+ f" - components that are a subclass of ConfigMixin but not a nn.Module (e.g. guider, scheduler)."
+ )
+
+ type_hint = component.__class__
+
+ if isinstance(component, ConfigMixin) and default_creation_method == "from_config":
+ config = component.config
+ else:
+ config = None
+ if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
+ load_spec = cls.decode_load_id(component._diffusers_load_id)
+ else:
+ load_spec = {}
+
+ return cls(
+ name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec
+ )
+
+ @classmethod
+ def loading_fields(cls) -> List[str]:
+ """
+ Return the names of all loading‐related fields (i.e. those whose field.metadata["loading"] is True).
+ """
+ return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
+
+ @property
+ def load_id(self) -> str:
+ """
+ Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
+ segments).
+ """
+ if self.default_creation_method == "from_config":
+ return "null"
+ parts = [getattr(self, k) for k in self.loading_fields()]
+ parts = ["null" if p is None else p for p in parts]
+ return "|".join(p for p in parts if p)
+
+ @classmethod
+ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
+ """
+ Decode a load_id string back into a dictionary of loading fields and values.
+
+ Args:
+ load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
+ where None values are represented as "null"
+
+ Returns:
+ Dict mapping loading field names to their values. e.g. {
+ "repo": "path/to/repo", "subfolder": "subfolder", "variant": "variant", "revision": "revision"
+ } If a segment value is "null", it's replaced with None. Returns None if load_id is "null" (indicating
+ component not created with `load` method).
+ """
+
+ # Get all loading fields in order
+ loading_fields = cls.loading_fields()
+ result = {f: None for f in loading_fields}
+
+ if load_id == "null":
+ return result
+
+ # Split the load_id
+ parts = load_id.split("|")
+
+ # Map parts to loading fields by position
+ for i, part in enumerate(parts):
+ if i < len(loading_fields):
+ # Convert "null" string back to None
+ result[loading_fields[i]] = None if part == "null" else part
+
+ return result
+
+ # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
+ # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
+ # the config info is lost in the process
+ # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method
+ def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
+ """Create component using from_config with config."""
+
+ if self.type_hint is None or not isinstance(self.type_hint, type):
+ raise ValueError("`type_hint` is required when using from_config creation method.")
+
+ config = config or self.config or {}
+
+ if issubclass(self.type_hint, ConfigMixin):
+ component = self.type_hint.from_config(config, **kwargs)
+ else:
+ signature_params = inspect.signature(self.type_hint.__init__).parameters
+ init_kwargs = {}
+ for k, v in config.items():
+ if k in signature_params:
+ init_kwargs[k] = v
+ for k, v in kwargs.items():
+ if k in signature_params:
+ init_kwargs[k] = v
+ component = self.type_hint(**init_kwargs)
+
+ component._diffusers_load_id = "null"
+ if hasattr(component, "config"):
+ self.config = component.config
+
+ return component
+
+ # YiYi TODO: add guard for type of model, if it is supported by from_pretrained
+ def load(self, **kwargs) -> Any:
+ """Load component using from_pretrained."""
+
+ # select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
+ passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
+ # merge loading field value in the spec with user passed values to create load_kwargs
+ load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
+ # repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
+ repo = load_kwargs.pop("repo", None)
+ if repo is None:
+ raise ValueError(
+ "`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)"
+ )
+
+ if self.type_hint is None:
+ try:
+ from diffusers import AutoModel
+
+ component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
+ except Exception as e:
+ raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
+ # update type_hint if AutoModel load successfully
+ self.type_hint = component.__class__
+ else:
+ try:
+ component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
+ except Exception as e:
+ raise ValueError(f"Unable to load {self.name} using load method: {e}")
+
+ self.repo = repo
+ for k, v in load_kwargs.items():
+ setattr(self, k, v)
+ component._diffusers_load_id = self.load_id
+
+ return component
+
+
+@dataclass
+class ConfigSpec:
+ """Specification for a pipeline configuration parameter."""
+
+ name: str
+ default: Any
+ description: Optional[str] = None
+
+
+# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
+# however some fields are not relevant for intermediate_inputs
+# e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed
+# default is not used for intermediate_inputs, we only use default from inputs, so it is ignored if it is set for intermediate_inputs
+# -> should we use different class for inputs and intermediate_inputs?
+@dataclass
+class InputParam:
+ """Specification for an input parameter."""
+
+ name: str = None
+ type_hint: Any = None
+ default: Any = None
+ required: bool = False
+ description: str = ""
+ kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
+
+ def __repr__(self):
+ return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
+
+
+@dataclass
+class OutputParam:
+ """Specification for an output parameter."""
+
+ name: str
+ type_hint: Any = None
+ description: str = ""
+ kwargs_type: str = None # YiYi notes: remove this feature (maybe)
+
+ def __repr__(self):
+ return (
+ f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
+ )
+
+
+def format_inputs_short(inputs):
+ """
+ Format input parameters into a string representation, with required params first followed by optional ones.
+
+ Args:
+ inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
+
+ Returns:
+ str: Formatted string of input parameters
+
+ Example:
+ >>> inputs = [ ... InputParam(name="prompt", required=True), ... InputParam(name="image", required=True), ...
+ InputParam(name="guidance_scale", required=False, default=7.5), ... InputParam(name="num_inference_steps",
+ required=False, default=50) ... ] >>> format_inputs_short(inputs) 'prompt, image, guidance_scale=7.5,
+ num_inference_steps=50'
+ """
+ required_inputs = [param for param in inputs if param.required]
+ optional_inputs = [param for param in inputs if not param.required]
+
+ required_str = ", ".join(param.name for param in required_inputs)
+ optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
+
+ inputs_str = required_str
+ if optional_str:
+ inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
+
+ return inputs_str
+
+
+def format_intermediates_short(intermediate_inputs, required_intermediate_inputs, intermediate_outputs):
+ """
+ Formats intermediate inputs and outputs of a block into a string representation.
+
+ Args:
+ intermediate_inputs: List of intermediate input parameters
+ required_intermediate_inputs: List of required intermediate input names
+ intermediate_outputs: List of intermediate output parameters
+
+ Returns:
+ str: Formatted string like:
+ Intermediates:
+ - inputs: Required(latents), dtype
+ - modified: latents # variables that appear in both inputs and outputs
+ - outputs: images # new outputs only
+ """
+ # Handle inputs
+ input_parts = []
+ for inp in intermediate_inputs:
+ if inp.name in required_intermediate_inputs:
+ input_parts.append(f"Required({inp.name})")
+ else:
+ if inp.name is None and inp.kwargs_type is not None:
+ inp_name = "*_" + inp.kwargs_type
+ else:
+ inp_name = inp.name
+ input_parts.append(inp_name)
+
+ # Handle modified variables (appear in both inputs and outputs)
+ inputs_set = {inp.name for inp in intermediate_inputs}
+ modified_parts = []
+ new_output_parts = []
+
+ for out in intermediate_outputs:
+ if out.name in inputs_set:
+ modified_parts.append(out.name)
+ else:
+ new_output_parts.append(out.name)
+
+ result = []
+ if input_parts:
+ result.append(f" - inputs: {', '.join(input_parts)}")
+ if modified_parts:
+ result.append(f" - modified: {', '.join(modified_parts)}")
+ if new_output_parts:
+ result.append(f" - outputs: {', '.join(new_output_parts)}")
+
+ return "\n".join(result) if result else " (none)"
+
+
+def format_params(params, header="Args", indent_level=4, max_line_length=115):
+ """Format a list of InputParam or OutputParam objects into a readable string representation.
+
+ Args:
+ params: List of InputParam or OutputParam objects to format
+ header: Header text to use (e.g. "Args" or "Returns")
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all parameters
+ """
+ if not params:
+ return ""
+
+ base_indent = " " * indent_level
+ param_indent = " " * (indent_level + 4)
+ desc_indent = " " * (indent_level + 8)
+ formatted_params = []
+
+ def get_type_str(type_hint):
+ if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
+ types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
+ return f"Union[{', '.join(types)}]"
+ return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
+
+ def wrap_text(text, indent, max_length):
+ """Wrap text while preserving markdown links and maintaining indentation."""
+ words = text.split()
+ lines = []
+ current_line = []
+ current_length = 0
+
+ for word in words:
+ word_length = len(word) + (1 if current_line else 0)
+
+ if current_line and current_length + word_length > max_length:
+ lines.append(" ".join(current_line))
+ current_line = [word]
+ current_length = len(word)
+ else:
+ current_line.append(word)
+ current_length += word_length
+
+ if current_line:
+ lines.append(" ".join(current_line))
+
+ return f"\n{indent}".join(lines)
+
+ # Add the header
+ formatted_params.append(f"{base_indent}{header}:")
+
+ for param in params:
+ # Format parameter name and type
+ type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
+ # YiYi Notes: remove this line if we remove kwargs_type
+ name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
+ param_str = f"{param_indent}{name} (`{type_str}`"
+
+ # Add optional tag and default value if parameter is an InputParam and optional
+ if hasattr(param, "required"):
+ if not param.required:
+ param_str += ", *optional*"
+ if param.default is not None:
+ param_str += f", defaults to {param.default}"
+ param_str += "):"
+
+ # Add description on a new line with additional indentation and wrapping
+ if param.description:
+ desc = re.sub(r"\[(.*?)\]\((https?://[^\s\)]+)\)", r"[\1](\2)", param.description)
+ wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
+ param_str += f"\n{desc_indent}{wrapped_desc}"
+
+ formatted_params.append(param_str)
+
+ return "\n\n".join(formatted_params)
+
+
+def format_input_params(input_params, indent_level=4, max_line_length=115):
+ """Format a list of InputParam objects into a readable string representation.
+
+ Args:
+ input_params: List of InputParam objects to format
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all input parameters
+ """
+ return format_params(input_params, "Inputs", indent_level, max_line_length)
+
+
+def format_output_params(output_params, indent_level=4, max_line_length=115):
+ """Format a list of OutputParam objects into a readable string representation.
+
+ Args:
+ output_params: List of OutputParam objects to format
+ indent_level: Number of spaces to indent each parameter line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+
+ Returns:
+ A formatted string representing all output parameters
+ """
+ return format_params(output_params, "Outputs", indent_level, max_line_length)
+
+
+def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
+ """Format a list of ComponentSpec objects into a readable string representation.
+
+ Args:
+ components: List of ComponentSpec objects to format
+ indent_level: Number of spaces to indent each component line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+ add_empty_lines: Whether to add empty lines between components (default: True)
+
+ Returns:
+ A formatted string representing all components
+ """
+ if not components:
+ return ""
+
+ base_indent = " " * indent_level
+ component_indent = " " * (indent_level + 4)
+ formatted_components = []
+
+ # Add the header
+ formatted_components.append(f"{base_indent}Components:")
+ if add_empty_lines:
+ formatted_components.append("")
+
+ # Add each component with optional empty lines between them
+ for i, component in enumerate(components):
+ # Get type name, handling special cases
+ type_name = (
+ component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
+ )
+
+ component_desc = f"{component_indent}{component.name} (`{type_name}`)"
+ if component.description:
+ component_desc += f": {component.description}"
+
+ # Get the loading fields dynamically
+ loading_field_values = []
+ for field_name in component.loading_fields():
+ field_value = getattr(component, field_name)
+ if field_value is not None:
+ loading_field_values.append(f"{field_name}={field_value}")
+
+ # Add loading field information if available
+ if loading_field_values:
+ component_desc += f" [{', '.join(loading_field_values)}]"
+
+ formatted_components.append(component_desc)
+
+ # Add an empty line after each component except the last one
+ if add_empty_lines and i < len(components) - 1:
+ formatted_components.append("")
+
+ return "\n".join(formatted_components)
+
+
+def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
+ """Format a list of ConfigSpec objects into a readable string representation.
+
+ Args:
+ configs: List of ConfigSpec objects to format
+ indent_level: Number of spaces to indent each config line (default: 4)
+ max_line_length: Maximum length for each line before wrapping (default: 115)
+ add_empty_lines: Whether to add empty lines between configs (default: True)
+
+ Returns:
+ A formatted string representing all configs
+ """
+ if not configs:
+ return ""
+
+ base_indent = " " * indent_level
+ config_indent = " " * (indent_level + 4)
+ formatted_configs = []
+
+ # Add the header
+ formatted_configs.append(f"{base_indent}Configs:")
+ if add_empty_lines:
+ formatted_configs.append("")
+
+ # Add each config with optional empty lines between them
+ for i, config in enumerate(configs):
+ config_desc = f"{config_indent}{config.name} (default: {config.default})"
+ if config.description:
+ config_desc += f": {config.description}"
+ formatted_configs.append(config_desc)
+
+ # Add an empty line after each config except the last one
+ if add_empty_lines and i < len(configs) - 1:
+ formatted_configs.append("")
+
+ return "\n".join(formatted_configs)
+
+
+def make_doc_string(
+ inputs,
+ intermediate_inputs,
+ outputs,
+ description="",
+ class_name=None,
+ expected_components=None,
+ expected_configs=None,
+):
+ """
+ Generates a formatted documentation string describing the pipeline block's parameters and structure.
+
+ Args:
+ inputs: List of input parameters
+ intermediate_inputs: List of intermediate input parameters
+ outputs: List of output parameters
+ description (str, *optional*): Description of the block
+ class_name (str, *optional*): Name of the class to include in the documentation
+ expected_components (List[ComponentSpec], *optional*): List of expected components
+ expected_configs (List[ConfigSpec], *optional*): List of expected configurations
+
+ Returns:
+ str: A formatted string containing information about components, configs, call parameters,
+ intermediate inputs/outputs, and final outputs.
+ """
+ output = ""
+
+ # Add class name if provided
+ if class_name:
+ output += f"class {class_name}\n\n"
+
+ # Add description
+ if description:
+ desc_lines = description.strip().split("\n")
+ aligned_desc = "\n".join(" " + line for line in desc_lines)
+ output += aligned_desc + "\n\n"
+
+ # Add components section if provided
+ if expected_components and len(expected_components) > 0:
+ components_str = format_components(expected_components, indent_level=2)
+ output += components_str + "\n\n"
+
+ # Add configs section if provided
+ if expected_configs and len(expected_configs) > 0:
+ configs_str = format_configs(expected_configs, indent_level=2)
+ output += configs_str + "\n\n"
+
+ # Add inputs section
+ output += format_input_params(inputs + intermediate_inputs, indent_level=2)
+
+ # Add outputs section
+ output += "\n\n"
+ output += format_output_params(outputs, indent_level=2)
+
+ return output
diff --git a/src/diffusers/modular_pipelines/node_utils.py b/src/diffusers/modular_pipelines/node_utils.py
new file mode 100644
index 0000000000..fb9a03c755
--- /dev/null
+++ b/src/diffusers/modular_pipelines/node_utils.py
@@ -0,0 +1,665 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ..configuration_utils import ConfigMixin
+from ..image_processor import PipelineImageInput
+from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
+from .modular_pipeline_utils import InputParam
+
+
+logger = logging.getLogger(__name__)
+
+# YiYi Notes: this is actually for SDXL, put it here for now
+SDXL_INPUTS_SCHEMA = {
+ "prompt": InputParam(
+ "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
+ ),
+ "prompt_2": InputParam(
+ "prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
+ ),
+ "negative_prompt": InputParam(
+ "negative_prompt",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts not to guide the image generation",
+ ),
+ "negative_prompt_2": InputParam(
+ "negative_prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The negative prompt or prompts for text_encoder_2",
+ ),
+ "cross_attention_kwargs": InputParam(
+ "cross_attention_kwargs",
+ type_hint=Optional[dict],
+ description="Kwargs dictionary passed to the AttentionProcessor",
+ ),
+ "clip_skip": InputParam(
+ "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
+ ),
+ "image": InputParam(
+ "image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="The image(s) to modify for img2img or inpainting",
+ ),
+ "mask_image": InputParam(
+ "mask_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Mask image for inpainting, white pixels will be repainted",
+ ),
+ "generator": InputParam(
+ "generator",
+ type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
+ description="Generator(s) for deterministic generation",
+ ),
+ "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
+ "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
+ "num_images_per_prompt": InputParam(
+ "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
+ ),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
+ ),
+ "timesteps": InputParam(
+ "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
+ ),
+ "sigmas": InputParam(
+ "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
+ ),
+ "denoising_end": InputParam(
+ "denoising_end",
+ type_hint=Optional[float],
+ description="Fraction of denoising process to complete before termination",
+ ),
+ # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
+ "strength": InputParam(
+ "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
+ ),
+ "denoising_start": InputParam(
+ "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
+ ),
+ "padding_mask_crop": InputParam(
+ "padding_mask_crop",
+ type_hint=Optional[Tuple[int, int]],
+ description="Size of margin in crop for image and mask",
+ ),
+ "original_size": InputParam(
+ "original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Original size of the image for SDXL's micro-conditioning",
+ ),
+ "target_size": InputParam(
+ "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
+ ),
+ "negative_original_size": InputParam(
+ "negative_original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on image resolution",
+ ),
+ "negative_target_size": InputParam(
+ "negative_target_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on target resolution",
+ ),
+ "crops_coords_top_left": InputParam(
+ "crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Top-left coordinates for SDXL's micro-conditioning",
+ ),
+ "negative_crops_coords_top_left": InputParam(
+ "negative_crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Negative conditioning crop coordinates",
+ ),
+ "aesthetic_score": InputParam(
+ "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
+ ),
+ "negative_aesthetic_score": InputParam(
+ "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
+ ),
+ "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
+ "output_type": InputParam(
+ "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
+ ),
+ "ip_adapter_image": InputParam(
+ "ip_adapter_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Image(s) to be used as IP adapter",
+ ),
+ "control_image": InputParam(
+ "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
+ ),
+ "control_guidance_start": InputParam(
+ "control_guidance_start",
+ type_hint=Union[float, List[float]],
+ default=0.0,
+ description="When ControlNet starts applying",
+ ),
+ "control_guidance_end": InputParam(
+ "control_guidance_end",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="When ControlNet stops applying",
+ ),
+ "controlnet_conditioning_scale": InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="Scale factor for ControlNet outputs",
+ ),
+ "guess_mode": InputParam(
+ "guess_mode",
+ type_hint=bool,
+ default=False,
+ description="Enables ControlNet encoder to recognize input without prompts",
+ ),
+ "control_mode": InputParam(
+ "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
+ ),
+}
+
+SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
+ "prompt_embeds": InputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ required=True,
+ description="Text embeddings used to guide image generation",
+ ),
+ "negative_prompt_embeds": InputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": InputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": InputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
+ "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "preprocess_kwargs": InputParam(
+ "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
+ ),
+ "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
+ ),
+ "latent_timestep": InputParam(
+ "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
+ ),
+ "image_latents": InputParam(
+ "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
+ ),
+ "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
+ "masked_image_latents": InputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "add_time_ids": InputParam(
+ "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
+ ),
+ "negative_add_time_ids": InputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "ip_adapter_embeds": InputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ required=True,
+ description="Generated images",
+ ),
+}
+
+SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
+
+
+DEFAULT_PARAM_MAPS = {
+ "prompt": {
+ "label": "Prompt",
+ "type": "string",
+ "default": "a bear sitting in a chair drinking a milkshake",
+ "display": "textarea",
+ },
+ "negative_prompt": {
+ "label": "Negative Prompt",
+ "type": "string",
+ "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
+ "display": "textarea",
+ },
+ "num_inference_steps": {
+ "label": "Steps",
+ "type": "int",
+ "default": 25,
+ "min": 1,
+ "max": 1000,
+ },
+ "seed": {
+ "label": "Seed",
+ "type": "int",
+ "default": 0,
+ "min": 0,
+ "display": "random",
+ },
+ "width": {
+ "label": "Width",
+ "type": "int",
+ "display": "text",
+ "default": 1024,
+ "min": 8,
+ "max": 8192,
+ "step": 8,
+ "group": "dimensions",
+ },
+ "height": {
+ "label": "Height",
+ "type": "int",
+ "display": "text",
+ "default": 1024,
+ "min": 8,
+ "max": 8192,
+ "step": 8,
+ "group": "dimensions",
+ },
+ "images": {
+ "label": "Images",
+ "type": "image",
+ "display": "output",
+ },
+ "image": {
+ "label": "Image",
+ "type": "image",
+ "display": "input",
+ },
+}
+
+DEFAULT_TYPE_MAPS = {
+ "int": {
+ "type": "int",
+ "default": 0,
+ "min": 0,
+ },
+ "float": {
+ "type": "float",
+ "default": 0.0,
+ "min": 0.0,
+ },
+ "str": {
+ "type": "string",
+ "default": "",
+ },
+ "bool": {
+ "type": "boolean",
+ "default": False,
+ },
+ "image": {
+ "type": "image",
+ },
+}
+
+DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
+DEFAULT_CATEGORY = "Modular Diffusers"
+DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
+DEFAULT_PARAMS_GROUPS_KEYS = {
+ "text_encoders": ["text_encoder", "tokenizer"],
+ "ip_adapter_embeds": ["ip_adapter_embeds"],
+ "prompt_embeddings": ["prompt_embeds"],
+}
+
+
+def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
+ """
+ Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
+ "text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
+ """
+ if name is None:
+ return None
+ for group_name, group_keys in group_params_keys.items():
+ for group_key in group_keys:
+ if group_key in name:
+ return group_name
+ return None
+
+
+class ModularNode(ConfigMixin):
+ """
+ A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
+ around a ModularPipelineBlocks object.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ config_name = "node_config.json"
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: str,
+ trust_remote_code: Optional[bool] = None,
+ **kwargs,
+ ):
+ blocks = ModularPipelineBlocks.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
+ )
+ return cls(blocks, **kwargs)
+
+ def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
+ self.blocks = blocks
+
+ if label is None:
+ label = self.blocks.__class__.__name__
+ # blocks param name -> mellon param name
+ self.name_mapping = {}
+
+ input_params = {}
+ # pass or create a default param dict for each input
+ # e.g. for prompt,
+ # prompt = {
+ # "name": "text_input", # the name of the input in node defination, could be different from the input name in diffusers
+ # "label": "Prompt",
+ # "type": "string",
+ # "default": "a bear sitting in a chair drinking a milkshake",
+ # "display": "textarea"}
+ # if type is not specified, it'll be a "custom" param of its own type
+ # e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
+ # it will get this spec in node defination {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
+ # name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
+ inputs = self.blocks.inputs + self.blocks.intermediate_inputs
+ for inp in inputs:
+ param = kwargs.pop(inp.name, None)
+ if param:
+ # user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
+ input_params[inp.name] = param
+ mellon_name = param.pop("name", inp.name)
+ if mellon_name != inp.name:
+ self.name_mapping[inp.name] = mellon_name
+ continue
+
+ if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
+ continue
+
+ if inp.name in DEFAULT_PARAM_MAPS:
+ # first check if it's in the default param map, if so, directly use that
+ param = DEFAULT_PARAM_MAPS[inp.name].copy()
+ elif get_group_name(inp.name):
+ param = get_group_name(inp.name)
+ if inp.name not in self.name_mapping:
+ self.name_mapping[inp.name] = param
+ else:
+ # if not, check if it's in the SDXL input schema, if so,
+ # 1. use the type hint to determine the type
+ # 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
+ if inp.type_hint is not None:
+ type_str = str(inp.type_hint).lower()
+ else:
+ inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
+ type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
+ for type_key, type_param in DEFAULT_TYPE_MAPS.items():
+ if type_key in type_str:
+ param = type_param.copy()
+ param["label"] = inp.name
+ param["display"] = "input"
+ break
+ else:
+ param = inp.name
+ # add the param dict to the inp_params dict
+ input_params[inp.name] = param
+
+ component_params = {}
+ for comp in self.blocks.expected_components:
+ param = kwargs.pop(comp.name, None)
+ if param:
+ component_params[comp.name] = param
+ mellon_name = param.pop("name", comp.name)
+ if mellon_name != comp.name:
+ self.name_mapping[comp.name] = mellon_name
+ continue
+
+ to_exclude = False
+ for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
+ if exclude_key in comp.name:
+ to_exclude = True
+ break
+ if to_exclude:
+ continue
+
+ if get_group_name(comp.name):
+ param = get_group_name(comp.name)
+ if comp.name not in self.name_mapping:
+ self.name_mapping[comp.name] = param
+ elif comp.name in DEFAULT_MODEL_KEYS:
+ param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
+ else:
+ param = comp.name
+ # add the param dict to the model_params dict
+ component_params[comp.name] = param
+
+ output_params = {}
+ if isinstance(self.blocks, SequentialPipelineBlocks):
+ last_block_name = list(self.blocks.sub_blocks.keys())[-1]
+ outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
+ else:
+ outputs = self.blocks.intermediate_outputs
+
+ for out in outputs:
+ param = kwargs.pop(out.name, None)
+ if param:
+ output_params[out.name] = param
+ mellon_name = param.pop("name", out.name)
+ if mellon_name != out.name:
+ self.name_mapping[out.name] = mellon_name
+ continue
+
+ if out.name in DEFAULT_PARAM_MAPS:
+ param = DEFAULT_PARAM_MAPS[out.name].copy()
+ param["display"] = "output"
+ else:
+ group_name = get_group_name(out.name)
+ if group_name:
+ param = group_name
+ if out.name not in self.name_mapping:
+ self.name_mapping[out.name] = param
+ else:
+ param = out.name
+ # add the param dict to the outputs dict
+ output_params[out.name] = param
+
+ if len(kwargs) > 0:
+ logger.warning(f"Unused kwargs: {kwargs}")
+
+ register_dict = {
+ "category": category,
+ "label": label,
+ "input_params": input_params,
+ "component_params": component_params,
+ "output_params": output_params,
+ "name_mapping": self.name_mapping,
+ }
+ self.register_to_config(**register_dict)
+
+ def setup(self, components_manager, collection=None):
+ self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
+ self._components_manager = components_manager
+
+ @property
+ def mellon_config(self):
+ return self._convert_to_mellon_config()
+
+ def _convert_to_mellon_config(self):
+ node = {}
+ node["label"] = self.config.label
+ node["category"] = self.config.category
+
+ node_param = {}
+ for inp_name, inp_param in self.config.input_params.items():
+ if inp_name in self.name_mapping:
+ mellon_name = self.name_mapping[inp_name]
+ else:
+ mellon_name = inp_name
+ if isinstance(inp_param, str):
+ param = {
+ "label": inp_param,
+ "type": inp_param,
+ "display": "input",
+ }
+ else:
+ param = inp_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
+
+ for comp_name, comp_param in self.config.component_params.items():
+ if comp_name in self.name_mapping:
+ mellon_name = self.name_mapping[comp_name]
+ else:
+ mellon_name = comp_name
+ if isinstance(comp_param, str):
+ param = {
+ "label": comp_param,
+ "type": comp_param,
+ "display": "input",
+ }
+ else:
+ param = comp_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
+
+ for out_name, out_param in self.config.output_params.items():
+ if out_name in self.name_mapping:
+ mellon_name = self.name_mapping[out_name]
+ else:
+ mellon_name = out_name
+ if isinstance(out_param, str):
+ param = {
+ "label": out_param,
+ "type": out_param,
+ "display": "output",
+ }
+ else:
+ param = out_param
+
+ if mellon_name not in node_param:
+ node_param[mellon_name] = param
+ else:
+ logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
+ node["params"] = node_param
+ return node
+
+ def save_mellon_config(self, file_path):
+ """
+ Save the Mellon configuration to a JSON file.
+
+ Args:
+ file_path (str or Path): Path where the JSON file will be saved
+
+ Returns:
+ Path: Path to the saved config file
+ """
+ file_path = Path(file_path)
+
+ # Create directory if it doesn't exist
+ os.makedirs(file_path.parent, exist_ok=True)
+
+ # Create a combined dictionary with module definition and name mapping
+ config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
+
+ # Save the config to file
+ with open(file_path, "w", encoding="utf-8") as f:
+ json.dump(config, f, indent=2)
+
+ logger.info(f"Mellon config and name mapping saved to {file_path}")
+
+ return file_path
+
+ @classmethod
+ def load_mellon_config(cls, file_path):
+ """
+ Load a Mellon configuration from a JSON file.
+
+ Args:
+ file_path (str or Path): Path to the JSON file containing Mellon config
+
+ Returns:
+ dict: The loaded combined configuration containing 'module' and 'name_mapping'
+ """
+ file_path = Path(file_path)
+
+ if not file_path.exists():
+ raise FileNotFoundError(f"Config file not found: {file_path}")
+
+ with open(file_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+
+ logger.info(f"Mellon config loaded from {file_path}")
+
+ return config
+
+ def process_inputs(self, **kwargs):
+ params_components = {}
+ for comp_name, comp_param in self.config.component_params.items():
+ logger.debug(f"component: {comp_name}")
+ mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
+ if mellon_comp_name in kwargs:
+ if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
+ comp = kwargs[mellon_comp_name].pop(comp_name)
+ else:
+ comp = kwargs.pop(mellon_comp_name)
+ if comp:
+ params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
+
+ params_run = {}
+ for inp_name, inp_param in self.config.input_params.items():
+ logger.debug(f"input: {inp_name}")
+ mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
+ if mellon_inp_name in kwargs:
+ if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
+ inp = kwargs[mellon_inp_name].pop(inp_name)
+ else:
+ inp = kwargs.pop(mellon_inp_name)
+ if inp is not None:
+ params_run[inp_name] = inp
+
+ return_output_names = list(self.config.output_params.keys())
+
+ return params_components, params_run, return_output_names
+
+ def execute(self, **kwargs):
+ params_components, params_run, return_output_names = self.process_inputs(**kwargs)
+
+ self.pipeline.update_components(**params_components)
+ output = self.pipeline(**params_run, output=return_output_names)
+ return output
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py
new file mode 100644
index 0000000000..59ec46dc6d
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py
@@ -0,0 +1,77 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "CONTROLNET_BLOCKS",
+ "IMAGE2IMAGE_BLOCKS",
+ "INPAINT_BLOCKS",
+ "IP_ADAPTER_BLOCKS",
+ "TEXT2IMAGE_BLOCKS",
+ "StableDiffusionXLAutoBlocks",
+ "StableDiffusionXLAutoControlnetStep",
+ "StableDiffusionXLAutoDecodeStep",
+ "StableDiffusionXLAutoIPAdapterStep",
+ "StableDiffusionXLAutoVaeEncoderStep",
+ ]
+ _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import (
+ StableDiffusionXLTextEncoderStep,
+ )
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ CONTROLNET_BLOCKS,
+ IMAGE2IMAGE_BLOCKS,
+ INPAINT_BLOCKS,
+ IP_ADAPTER_BLOCKS,
+ TEXT2IMAGE_BLOCKS,
+ StableDiffusionXLAutoBlocks,
+ StableDiffusionXLAutoControlnetStep,
+ StableDiffusionXLAutoDecodeStep,
+ StableDiffusionXLAutoIPAdapterStep,
+ StableDiffusionXLAutoVaeEncoderStep,
+ )
+ from .modular_pipeline import StableDiffusionXLModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
new file mode 100644
index 0000000000..c56f4af1b8
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py
@@ -0,0 +1,1929 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, List, Optional, Tuple, Union
+
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, UNet2DConditionModel
+from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor, unwrap_module
+from ..modular_pipeline import (
+ PipelineBlock,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
+# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
+# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
+# configuration of guider is.
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+def prepare_latents_img2img(
+ vae, scheduler, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
+):
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
+ raise ValueError(f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}")
+
+ image = image.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ if image.shape[1] == 4:
+ init_latents = image
+
+ else:
+ latents_mean = latents_std = None
+ if hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(vae.config, "latents_std") and vae.config.latents_std is not None:
+ latents_std = torch.tensor(vae.config.latents_std).view(1, 4, 1, 1)
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if vae.config.force_upcast:
+ image = image.float()
+ vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ elif isinstance(generator, list):
+ if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+ image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+ elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+ )
+
+ init_latents = [
+ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(batch_size)
+ ]
+ init_latents = torch.cat(init_latents, dim=0)
+ else:
+ init_latents = retrieve_latents(vae.encode(image), generator=generator)
+
+ if vae.config.force_upcast:
+ vae.to(dtype)
+
+ init_latents = init_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=device, dtype=dtype)
+ latents_std = latents_std.to(device=device, dtype=dtype)
+ init_latents = (init_latents - latents_mean) * vae.config.scaling_factor / latents_std
+ else:
+ init_latents = vae.config.scaling_factor * init_latents
+
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ init_latents = torch.cat([init_latents], dim=0)
+
+ if add_noise:
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = scheduler.add_noise(init_latents, noise, timestep)
+
+ latents = init_latents
+
+ return latents
+
+
+class StableDiffusionXLInputStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Input processing step that:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n"
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+ "have a final batch_size of batch_size * num_images_per_prompt."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_pooled_prompt_embeds",
+ description="Pre-generated negative pooled text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.",
+ ),
+ InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="negative text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="negative pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="image embeddings for IP-Adapter",
+ ),
+ OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="negative image embeddings for IP-Adapter",
+ ),
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {block_state.negative_prompt_embeds.shape}."
+ )
+
+ if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+
+ if block_state.negative_prompt_embeds is not None and block_state.negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if block_state.ip_adapter_embeds is not None and not isinstance(block_state.ip_adapter_embeds, list):
+ raise ValueError("`ip_adapter_embeds` must be a list")
+
+ if block_state.negative_ip_adapter_embeds is not None and not isinstance(
+ block_state.negative_ip_adapter_embeds, list
+ ):
+ raise ValueError("`negative_ip_adapter_embeds` must be a list")
+
+ if block_state.ip_adapter_embeds is not None and block_state.negative_ip_adapter_embeds is not None:
+ for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
+ if ip_adapter_embed.shape != block_state.negative_ip_adapter_embeds[i].shape:
+ raise ValueError(
+ "`ip_adapter_embeds` and `negative_ip_adapter_embeds` must have the same shape when passed directly, but"
+ f" got: `ip_adapter_embeds` {ip_adapter_embed.shape} != `negative_ip_adapter_embeds`"
+ f" {block_state.negative_ip_adapter_embeds[i].shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
+ )
+
+ block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.pooled_prompt_embeds = block_state.pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+
+ if block_state.negative_pooled_prompt_embeds is not None:
+ block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(
+ 1, block_state.num_images_per_prompt, 1
+ )
+ block_state.negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.view(
+ block_state.batch_size * block_state.num_images_per_prompt, -1
+ )
+
+ if block_state.ip_adapter_embeds is not None:
+ for i, ip_adapter_embed in enumerate(block_state.ip_adapter_embeds):
+ block_state.ip_adapter_embeds[i] = torch.cat(
+ [ip_adapter_embed] * block_state.num_images_per_prompt, dim=0
+ )
+
+ if block_state.negative_ip_adapter_embeds is not None:
+ for i, negative_ip_adapter_embed in enumerate(block_state.negative_ip_adapter_embeds):
+ block_state.negative_ip_adapter_embeds[i] = torch.cat(
+ [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step that sets the timesteps for the scheduler and determines the initial noise level (latent_timestep) for image-to-image/inpainting generation.\n"
+ + "The latent_timestep is calculated from the `strength` parameter - higher strength means starting from a noisier version of the input image."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("denoising_end"),
+ InputParam("strength", default=0.3),
+ InputParam("denoising_start"),
+ # YiYi TODO: do we need num_images_per_prompt here?
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ OutputParam(
+ "latent_timestep",
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image generation",
+ ),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps with self->components
+ def get_timesteps(components, num_inference_steps, strength, device, denoising_start=None):
+ # get the original timestep using init_timestep
+ if denoising_start is None:
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+
+ timesteps = components.scheduler.timesteps[t_start * components.scheduler.order :]
+ if hasattr(components.scheduler, "set_begin_index"):
+ components.scheduler.set_begin_index(t_start * components.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ else:
+ # Strength is irrelevant if we directly request a timestep to start at;
+ # that is, strength is determined by the denoising_start instead.
+ discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (denoising_start * components.scheduler.config.num_train_timesteps)
+ )
+ )
+
+ num_inference_steps = (components.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
+ if components.scheduler.order == 2 and num_inference_steps % 2 == 0:
+ # if the scheduler is a 2nd order scheduler we might have to do +1
+ # because `num_inference_steps` might be even given that every timestep
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+ # mean that we cut the timesteps in the middle of the denoising step
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+ # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+ num_inference_steps = num_inference_steps + 1
+
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
+ t_start = len(components.scheduler.timesteps) - num_inference_steps
+ timesteps = components.scheduler.timesteps[t_start:]
+ if hasattr(components.scheduler, "set_begin_index"):
+ components.scheduler.set_begin_index(t_start)
+ return timesteps, num_inference_steps
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ block_state.device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ def denoising_value_valid(dnv):
+ return isinstance(dnv, float) and 0 < dnv < 1
+
+ block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
+ components,
+ block_state.num_inference_steps,
+ block_state.strength,
+ block_state.device,
+ denoising_start=block_state.denoising_start
+ if denoising_value_valid(block_state.denoising_start)
+ else None,
+ )
+ block_state.latent_timestep = block_state.timesteps[:1].repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+
+ if (
+ block_state.denoising_end is not None
+ and isinstance(block_state.denoising_end, float)
+ and block_state.denoising_end > 0
+ and block_state.denoising_end < 1
+ ):
+ block_state.discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
+ )
+ )
+ block_state.num_inference_steps = len(
+ list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+ )
+ block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLSetTimestepsStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ InputParam("denoising_end"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ block_state.device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ if (
+ block_state.denoising_end is not None
+ and isinstance(block_state.denoising_end, float)
+ and block_state.denoising_end > 0
+ and block_state.denoising_end < 1
+ ):
+ block_state.discrete_timestep_cutoff = int(
+ round(
+ components.scheduler.config.num_train_timesteps
+ - (block_state.denoising_end * components.scheduler.config.num_train_timesteps)
+ )
+ )
+ block_state.num_inference_steps = len(
+ list(filter(lambda ts: ts >= block_state.discrete_timestep_cutoff, block_state.timesteps))
+ )
+ block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps]
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the latents for the inpainting process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("denoising_start"),
+ InputParam(
+ "strength",
+ default=0.9999,
+ description="Conceptually, indicates how much to transform the reference `image` (the masked portion of image for inpainting). Must be between 0 and 1. `image` "
+ "will be used as a starting point, adding more noise to it the larger the `strength`. The number of "
+ "denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will "
+ "be maximum and the denoising process will run for the full number of iterations specified in "
+ "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
+ "`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
+ ),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "latent_timestep",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "mask",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The mask for the inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "masked_image_latents",
+ type_hint=torch.Tensor,
+ description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ ),
+ OutputParam(
+ "noise",
+ type_hint=torch.Tensor,
+ description="The noise added to the image latents, used for inpainting generation",
+ ),
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self->components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ @staticmethod
+ def _encode_vae_image(components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_latents adding components as first argument
+ def prepare_latents_inpaint(
+ self,
+ components,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ add_noise=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // components.vae_scale_factor,
+ int(width) // components.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if image.shape[1] == 4:
+ image_latents = image.to(device=device, dtype=dtype)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+ elif return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(components, image=image, generator=generator)
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+ if latents is None and add_noise:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = noise if is_strength_max else components.scheduler.add_noise(image_latents, noise, timestep)
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = latents * components.scheduler.init_noise_sigma if is_strength_max else latents
+ elif add_noise:
+ noise = latents.to(device)
+ latents = noise * components.scheduler.init_noise_sigma
+ else:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = image_latents.to(device)
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+ # do not accept do_classifier_free_guidance
+ def prepare_mask_latents(
+ self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+
+ block_state.is_strength_max = block_state.strength == 1.0
+
+ # for non-inpainting specific unet, we do not need masked_image_latents
+ if hasattr(components, "unet") and components.unet is not None:
+ if components.unet.config.in_channels == 4:
+ block_state.masked_image_latents = None
+
+ block_state.add_noise = True if block_state.denoising_start is None else False
+
+ block_state.height = block_state.image_latents.shape[-2] * components.vae_scale_factor
+ block_state.width = block_state.image_latents.shape[-1] * components.vae_scale_factor
+
+ block_state.latents, block_state.noise = self.prepare_latents_inpaint(
+ components,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ components.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ image=block_state.image_latents,
+ timestep=block_state.latent_timestep,
+ is_strength_max=block_state.is_strength_max,
+ add_noise=block_state.add_noise,
+ return_noise=True,
+ return_image_latents=False,
+ )
+
+ # 7. Prepare mask latent variables
+ block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
+ components,
+ block_state.mask,
+ block_state.masked_image_latents,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the latents for the image-to-image generation process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("denoising_start"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "latent_timestep",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "image_latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+ block_state.add_noise = True if block_state.denoising_start is None else False
+ if block_state.latents is None:
+ block_state.latents = prepare_latents_img2img(
+ components.vae,
+ components.scheduler,
+ block_state.image_latents,
+ block_state.latent_timestep,
+ block_state.batch_size,
+ block_state.num_images_per_prompt,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.add_noise,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("vae", AutoencoderKL),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-image generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("latents"),
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (
+ block_state.height is not None
+ and block_state.height % components.vae_scale_factor != 0
+ or block_state.width is not None
+ and block_state.width % components.vae_scale_factor != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor} but are {block_state.height} and {block_state.width}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents with self->comp
+ def prepare_latents(comp, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ int(height) // comp.vae_scale_factor,
+ int(width) // comp.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * comp.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if block_state.dtype is None:
+ block_state.dtype = components.vae.dtype
+
+ block_state.device = components._execution_device
+
+ self.check_inputs(components, block_state)
+
+ block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
+ block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
+ block_state.num_channels_latents = components.num_channels_latents
+ block_state.latents = self.prepare_latents(
+ components,
+ block_state.batch_size * block_state.num_images_per_prompt,
+ block_state.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [
+ ConfigSpec("requires_aesthetics_score", False),
+ ]
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the additional conditioning for the image-to-image/inpainting generation process"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("original_size"),
+ InputParam("target_size"),
+ InputParam("negative_original_size"),
+ InputParam("negative_target_size"),
+ InputParam("crops_coords_top_left", default=(0, 0)),
+ InputParam("negative_crops_coords_top_left", default=(0, 0)),
+ InputParam("num_images_per_prompt", default=1),
+ InputParam("aesthetic_score", default=6.0),
+ InputParam("negative_aesthetic_score", default=2.0),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="The time ids to condition the denoising process",
+ ),
+ OutputParam(
+ "negative_add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="The negative time ids to condition the denoising process",
+ ),
+ OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids with self->components
+ def _get_add_time_ids(
+ components,
+ original_size,
+ crops_coords_top_left,
+ target_size,
+ aesthetic_score,
+ negative_aesthetic_score,
+ negative_original_size,
+ negative_crops_coords_top_left,
+ negative_target_size,
+ dtype,
+ text_encoder_projection_dim=None,
+ ):
+ if components.config.requires_aesthetics_score:
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+ add_neg_time_ids = list(
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+ )
+ else:
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+ passed_add_embed_dim = (
+ components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features
+
+ if (
+ expected_add_embed_dim > passed_add_embed_dim
+ and (expected_add_embed_dim - passed_add_embed_dim) == components.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+ )
+ elif (
+ expected_add_embed_dim < passed_add_embed_dim
+ and (passed_add_embed_dim - expected_add_embed_dim) == components.unet.config.addition_time_embed_dim
+ ):
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+ )
+ elif expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+ return add_time_ids, add_neg_time_ids
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.vae_scale_factor = components.vae_scale_factor
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * block_state.vae_scale_factor
+ block_state.width = block_state.width * block_state.vae_scale_factor
+
+ block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
+ block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
+
+ block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
+
+ if block_state.negative_original_size is None:
+ block_state.negative_original_size = block_state.original_size
+ if block_state.negative_target_size is None:
+ block_state.negative_target_size = block_state.target_size
+
+ block_state.add_time_ids, block_state.negative_add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.original_size,
+ block_state.crops_coords_top_left,
+ block_state.target_size,
+ block_state.aesthetic_score,
+ block_state.negative_aesthetic_score,
+ block_state.negative_original_size,
+ block_state.negative_crops_coords_top_left,
+ block_state.negative_target_size,
+ dtype=block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ block_state.add_time_ids = block_state.add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+ block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+
+ # Optionally get Guidance Scale Embedding for LCM
+ block_state.timestep_cond = None
+ if (
+ hasattr(components, "unet")
+ and components.unet is not None
+ and components.unet.config.time_cond_proj_dim is not None
+ ):
+ # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+ block_state.timestep_cond = self.get_guidance_scale_embedding(
+ block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
+ ).to(device=block_state.device, dtype=block_state.latents.dtype)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Step that prepares the additional conditioning for the text-to-image generation process"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("original_size"),
+ InputParam("target_size"),
+ InputParam("negative_original_size"),
+ InputParam("negative_target_size"),
+ InputParam("crops_coords_top_left", default=(0, 0)),
+ InputParam("negative_crops_coords_top_left", default=(0, 0)),
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "pooled_prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The pooled prompt embeddings to use for the denoising process (used to determine shapes and dtypes for other additional conditioning inputs). Can be generated in text_encoder step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="The time ids to condition the denoising process",
+ ),
+ OutputParam(
+ "negative_add_time_ids",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="The negative time ids to condition the denoising process",
+ ),
+ OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids with self->components
+ def _get_add_time_ids(
+ components, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+ ):
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
+
+ passed_add_embed_dim = (
+ components.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+ )
+ expected_add_embed_dim = components.unet.add_embedding.linear_1.in_features
+
+ if expected_add_embed_dim != passed_add_embed_dim:
+ raise ValueError(
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+ )
+
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+ return add_time_ids
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
+ ) -> torch.Tensor:
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+ Args:
+ w (`torch.Tensor`):
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
+ embedding_dim (`int`, *optional*, defaults to 512):
+ Dimension of the embeddings to generate.
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+ Data type of the generated embeddings.
+
+ Returns:
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ block_state.original_size = block_state.original_size or (block_state.height, block_state.width)
+ block_state.target_size = block_state.target_size or (block_state.height, block_state.width)
+
+ block_state.text_encoder_projection_dim = int(block_state.pooled_prompt_embeds.shape[-1])
+
+ block_state.add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.original_size,
+ block_state.crops_coords_top_left,
+ block_state.target_size,
+ block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ if block_state.negative_original_size is not None and block_state.negative_target_size is not None:
+ block_state.negative_add_time_ids = self._get_add_time_ids(
+ components,
+ block_state.negative_original_size,
+ block_state.negative_crops_coords_top_left,
+ block_state.negative_target_size,
+ block_state.pooled_prompt_embeds.dtype,
+ text_encoder_projection_dim=block_state.text_encoder_projection_dim,
+ )
+ else:
+ block_state.negative_add_time_ids = block_state.add_time_ids
+
+ block_state.add_time_ids = block_state.add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+ block_state.negative_add_time_ids = block_state.negative_add_time_ids.repeat(
+ block_state.batch_size * block_state.num_images_per_prompt, 1
+ ).to(device=block_state.device)
+
+ # Optionally get Guidance Scale Embedding for LCM
+ block_state.timestep_cond = None
+ if (
+ hasattr(components, "unet")
+ and components.unet is not None
+ and components.unet.config.time_cond_proj_dim is not None
+ ):
+ # TODO(yiyi, aryan): Ideally, this should be `embedded_guidance_scale` instead of pulling from guider. Guider scales should be different from this!
+ block_state.guidance_scale_tensor = torch.tensor(components.guider.guidance_scale - 1).repeat(
+ block_state.batch_size * block_state.num_images_per_prompt
+ )
+ block_state.timestep_cond = self.get_guidance_scale_embedding(
+ block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
+ ).to(device=block_state.device, dtype=block_state.latents.dtype)
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLControlNetInputStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", ControlNetModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepare inputs for controlnet"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("control_image", required=True),
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("guess_mode", default=False),
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int]],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_cond", type_hint=torch.Tensor, description="The processed control image"),
+ OutputParam(
+ "control_guidance_start", type_hint=List[float], description="The controlnet guidance start values"
+ ),
+ OutputParam(
+ "control_guidance_end", type_hint=List[float], description="The controlnet guidance end values"
+ ),
+ OutputParam(
+ "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"
+ ),
+ OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ # 1. return image without apply any guidance
+ # 2. add crops_coords and resize_mode to preprocess()
+ @staticmethod
+ def prepare_control_image(
+ components,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ crops_coords=None,
+ ):
+ if crops_coords is not None:
+ image = components.control_image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill"
+ ).to(dtype=torch.float32)
+ else:
+ image = components.control_image_processor.preprocess(image, height=height, width=width).to(
+ dtype=torch.float32
+ )
+
+ image_batch_size = image.shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype)
+ return image
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ # (1) prepare controlnet inputs
+ block_state.device = components._execution_device
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ controlnet = unwrap_module(components.controlnet)
+
+ # (1.1)
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+ elif not isinstance(block_state.control_guidance_start, list) and not isinstance(
+ block_state.control_guidance_end, list
+ ):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+ block_state.control_guidance_start, block_state.control_guidance_end = (
+ mult * [block_state.control_guidance_start],
+ mult * [block_state.control_guidance_end],
+ )
+
+ # (1.2)
+ # controlnet_conditioning_scale (align format)
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(
+ block_state.controlnet_conditioning_scale, float
+ ):
+ block_state.controlnet_conditioning_scale = [block_state.controlnet_conditioning_scale] * len(
+ controlnet.nets
+ )
+
+ # (1.3)
+ # global_pool_conditions
+ block_state.global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
+ # (1.4)
+ # guess_mode
+ block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
+
+ # (1.5)
+ # control_image
+ if isinstance(controlnet, ControlNetModel):
+ block_state.control_image = self.prepare_control_image(
+ components,
+ image=block_state.control_image,
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=block_state.device,
+ dtype=controlnet.dtype,
+ crops_coords=block_state.crops_coords,
+ )
+ elif isinstance(controlnet, MultiControlNetModel):
+ control_images = []
+
+ for control_image_ in block_state.control_image:
+ control_image = self.prepare_control_image(
+ components,
+ image=control_image_,
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=block_state.device,
+ dtype=controlnet.dtype,
+ crops_coords=block_state.crops_coords,
+ )
+
+ control_images.append(control_image)
+
+ block_state.control_image = control_images
+ else:
+ assert False
+
+ # (1.6)
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ keeps = [
+ 1.0 - float(i / len(block_state.timesteps) < s or (i + 1) / len(block_state.timesteps) > e)
+ for s, e in zip(block_state.control_guidance_start, block_state.control_guidance_end)
+ ]
+ block_state.controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
+ block_state.controlnet_cond = block_state.control_image
+ block_state.conditioning_scale = block_state.controlnet_conditioning_scale
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("controlnet", ControlNetUnionModel),
+ ComponentSpec(
+ "control_image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"do_convert_rgb": True, "do_normalize": False}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "step that prepares inputs for the ControlNetUnion model"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("control_image", required=True),
+ InputParam("control_mode", required=True),
+ InputParam("control_guidance_start", default=0.0),
+ InputParam("control_guidance_end", default=1.0),
+ InputParam("controlnet_conditioning_scale", default=1.0),
+ InputParam("guess_mode", default=False),
+ InputParam("num_images_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Used to determine the shape of the control images. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
+ ),
+ InputParam(
+ "dtype",
+ required=True,
+ type_hint=torch.dtype,
+ description="The dtype of model tensor inputs. Can be generated in input step.",
+ ),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Needed to determine `controlnet_keep`. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int]],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("controlnet_cond", type_hint=List[torch.Tensor], description="The processed control images"),
+ OutputParam(
+ "control_type_idx",
+ type_hint=List[int],
+ description="The control mode indices",
+ kwargs_type="controlnet_kwargs",
+ ),
+ OutputParam(
+ "control_type",
+ type_hint=torch.Tensor,
+ description="The control type tensor that specifies which control type is active",
+ kwargs_type="controlnet_kwargs",
+ ),
+ OutputParam("control_guidance_start", type_hint=float, description="The controlnet guidance start value"),
+ OutputParam("control_guidance_end", type_hint=float, description="The controlnet guidance end value"),
+ OutputParam(
+ "conditioning_scale", type_hint=List[float], description="The controlnet conditioning scale values"
+ ),
+ OutputParam("guess_mode", type_hint=bool, description="Whether guess mode is used"),
+ OutputParam("controlnet_keep", type_hint=List[float], description="The controlnet keep values"),
+ ]
+
+ # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
+ # 1. return image without apply any guidance
+ # 2. add crops_coords and resize_mode to preprocess()
+ @staticmethod
+ def prepare_control_image(
+ components,
+ image,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ crops_coords=None,
+ ):
+ if crops_coords is not None:
+ image = components.control_image_processor.preprocess(
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill"
+ ).to(dtype=torch.float32)
+ else:
+ image = components.control_image_processor.preprocess(image, height=height, width=width).to(
+ dtype=torch.float32
+ )
+
+ image_batch_size = image.shape[0]
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+ image = image.to(device=device, dtype=dtype)
+ return image
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ controlnet = unwrap_module(components.controlnet)
+
+ device = components._execution_device
+ dtype = block_state.dtype or components.controlnet.dtype
+
+ block_state.height, block_state.width = block_state.latents.shape[-2:]
+ block_state.height = block_state.height * components.vae_scale_factor
+ block_state.width = block_state.width * components.vae_scale_factor
+
+ # control_guidance_start/control_guidance_end (align format)
+ if not isinstance(block_state.control_guidance_start, list) and isinstance(
+ block_state.control_guidance_end, list
+ ):
+ block_state.control_guidance_start = len(block_state.control_guidance_end) * [
+ block_state.control_guidance_start
+ ]
+ elif not isinstance(block_state.control_guidance_end, list) and isinstance(
+ block_state.control_guidance_start, list
+ ):
+ block_state.control_guidance_end = len(block_state.control_guidance_start) * [
+ block_state.control_guidance_end
+ ]
+
+ # guess_mode
+ block_state.global_pool_conditions = controlnet.config.global_pool_conditions
+ block_state.guess_mode = block_state.guess_mode or block_state.global_pool_conditions
+
+ # control_image
+ if not isinstance(block_state.control_image, list):
+ block_state.control_image = [block_state.control_image]
+ # control_mode
+ if not isinstance(block_state.control_mode, list):
+ block_state.control_mode = [block_state.control_mode]
+
+ if len(block_state.control_image) != len(block_state.control_mode):
+ raise ValueError("Expected len(control_image) == len(control_type)")
+
+ # control_type
+ block_state.num_control_type = controlnet.config.num_control_type
+ block_state.control_type = [0 for _ in range(block_state.num_control_type)]
+ for control_idx in block_state.control_mode:
+ block_state.control_type[control_idx] = 1
+ block_state.control_type = torch.Tensor(block_state.control_type)
+
+ block_state.control_type = block_state.control_type.reshape(1, -1).to(device, dtype=block_state.dtype)
+ repeat_by = block_state.batch_size * block_state.num_images_per_prompt // block_state.control_type.shape[0]
+ block_state.control_type = block_state.control_type.repeat_interleave(repeat_by, dim=0)
+
+ # prepare control_image
+ for idx, _ in enumerate(block_state.control_image):
+ block_state.control_image[idx] = self.prepare_control_image(
+ components,
+ image=block_state.control_image[idx],
+ width=block_state.width,
+ height=block_state.height,
+ batch_size=block_state.batch_size * block_state.num_images_per_prompt,
+ num_images_per_prompt=block_state.num_images_per_prompt,
+ device=device,
+ dtype=dtype,
+ crops_coords=block_state.crops_coords,
+ )
+ block_state.height, block_state.width = block_state.control_image[idx].shape[-2:]
+
+ # controlnet_keep
+ block_state.controlnet_keep = []
+ for i in range(len(block_state.timesteps)):
+ block_state.controlnet_keep.append(
+ 1.0
+ - float(
+ i / len(block_state.timesteps) < block_state.control_guidance_start
+ or (i + 1) / len(block_state.timesteps) > block_state.control_guidance_end
+ )
+ )
+ block_state.control_type_idx = block_state.control_mode
+ block_state.controlnet_cond = block_state.control_image
+ block_state.conditioning_scale = block_state.controlnet_conditioning_scale
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
new file mode 100644
index 0000000000..e9f627636e
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py
@@ -0,0 +1,218 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL
+from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
+from ...utils import logging
+from ..modular_pipeline import (
+ PipelineBlock,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class StableDiffusionXLDecodeStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("output_type", default="pil"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self->components
+ def upcast_vae(components):
+ dtype = components.vae.dtype
+ components.vae.to(dtype=torch.float32)
+ use_torch_2_0_or_xformers = isinstance(
+ components.vae.decoder.mid_block.attentions[0].processor,
+ (
+ AttnProcessor2_0,
+ XFormersAttnProcessor,
+ ),
+ )
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if use_torch_2_0_or_xformers:
+ components.vae.post_quant_conv.to(dtype)
+ components.vae.decoder.conv_in.to(dtype)
+ components.vae.decoder.mid_block.to(dtype)
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if not block_state.output_type == "latent":
+ latents = block_state.latents
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
+
+ if block_state.needs_upcasting:
+ self.upcast_vae(components)
+ latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
+ elif latents.dtype != components.vae.dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ components.vae = components.vae.to(latents.dtype)
+
+ # unscale/denormalize the latents
+ # denormalize with the mean and std if available and not None
+ block_state.has_latents_mean = (
+ hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
+ )
+ block_state.has_latents_std = (
+ hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
+ )
+ if block_state.has_latents_mean and block_state.has_latents_std:
+ block_state.latents_mean = (
+ torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ block_state.latents_std = (
+ torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+ )
+ latents = (
+ latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
+ )
+ else:
+ latents = latents / components.vae.config.scaling_factor
+
+ block_state.images = components.vae.decode(latents, return_dict=False)[0]
+
+ # cast back to fp16 if needed
+ if block_state.needs_upcasting:
+ components.vae.to(dtype=torch.float16)
+ else:
+ block_state.images = block_state.latents
+
+ # apply watermark if available
+ if hasattr(components, "watermark") and components.watermark is not None:
+ block_state.images = components.watermark.apply_watermark(block_state.images)
+
+ block_state.images = components.image_processor.postprocess(
+ block_state.images, output_type=block_state.output_type
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "A post-processing step that overlays the mask on the image (inpainting task only).\n"
+ + "only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("image"),
+ InputParam("mask_image"),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="The generated images from the decode step",
+ ),
+ InputParam(
+ "crops_coords",
+ type_hint=Tuple[int, int],
+ description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
+ block_state.images = [
+ components.image_processor.apply_overlay(
+ block_state.mask_image, block_state.image, i, block_state.crops_coords
+ )
+ for i in block_state.images
+ ]
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
new file mode 100644
index 0000000000..7fe4a472ee
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py
@@ -0,0 +1,791 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, List, Optional, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import ControlNetModel, UNet2DConditionModel
+from ...schedulers import EulerDiscreteScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ PipelineBlock,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# YiYi experimenting composible denoise loop
+# loop step (1): prepare latent input for denoiser
+class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepare the latent input for the denoiser. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+
+ return components, block_state
+
+
+# loop step (1): prepare latent input for denoiser (with inpainting)
+class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that prepare the latent input for the denoiser (for inpainting workflow only). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object"
+ )
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "mask",
+ type_hint=Optional[torch.Tensor],
+ description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ InputParam(
+ "masked_image_latents",
+ type_hint=Optional[torch.Tensor],
+ description="The masked image latents to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ num_channels_unet = components.num_channels_unet
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ if block_state.mask is None or block_state.masked_image_latents is None:
+ raise ValueError("mask and masked_image_latents must be provided for inpainting-specific Unet")
+ num_channels_latents = block_state.latents.shape[1]
+ num_channels_mask = block_state.mask.shape[1]
+ num_channels_masked_image = block_state.masked_image_latents.shape[1]
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != num_channels_unet:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `components.unet`: {components.unet.config} expects"
+ f" {components.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
+ " `components.unet` or your `mask_image` or `image` input."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ self.check_inputs(components, block_state)
+
+ block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t)
+ if components.num_channels_unet == 9:
+ block_state.scaled_latents = torch.cat(
+ [block_state.scaled_latents, block_state.mask, block_state.masked_image_latents], dim=1
+ )
+
+ return components, block_state
+
+
+# loop step (2): denoise the latents with guidance
+class StableDiffusionXLLoopDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents with guidance. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("cross_attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "timestep_cond",
+ type_hint=Optional[torch.Tensor],
+ description="The guidance scale embedding to use for Latent Consistency Models(LCMs). Can be generated in prepare_additional_conditioning step.",
+ ),
+ InputParam(
+ kwargs_type="guider_input_fields",
+ description=(
+ "All conditional model inputs that need to be prepared with guider. "
+ "It should contain prompt_embeds/negative_prompt_embeds, "
+ "add_time_ids/negative_add_time_ids, "
+ "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
+ "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
+ "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int
+ ) -> PipelineState:
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+ guider_input_fields = {
+ "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+ "time_ids": ("add_time_ids", "negative_add_time_ids"),
+ "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+ "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+ }
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
+ # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.unet)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+ # Predict the noise residual
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.unet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=block_state.timestep_cond,
+ cross_attention_kwargs=block_state.cross_attention_kwargs,
+ added_cond_kwargs=cond_kwargs,
+ return_dict=False,
+ )[0]
+ components.guider.cleanup_models(components.unet)
+
+ # Perform guidance
+ block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+ return components, block_state
+
+
+# loop step (2): denoise the latents with guidance (with controlnet)
+class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec("controlnet", ControlNetModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that denoise the latents with guidance (with controlnet). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("cross_attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "controlnet_cond",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "conditioning_scale",
+ type_hint=float,
+ description="The controlnet conditioning scale value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "guess_mode",
+ required=True,
+ type_hint=bool,
+ description="The guess mode value to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "controlnet_keep",
+ required=True,
+ type_hint=List[float],
+ description="The controlnet keep values to use for the denoising process. Can be generated in prepare_controlnet_inputs step.",
+ ),
+ InputParam(
+ "timestep_cond",
+ type_hint=Optional[torch.Tensor],
+ description="The guidance scale embedding to use for Latent Consistency Models(LCMs), can be generated by prepare_additional_conditioning step",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="guider_input_fields",
+ description=(
+ "All conditional model inputs that need to be prepared with guider. "
+ "It should contain prompt_embeds/negative_prompt_embeds, "
+ "add_time_ids/negative_add_time_ids, "
+ "pooled_prompt_embeds/negative_pooled_prompt_embeds, "
+ "and ip_adapter_embeds/negative_ip_adapter_embeds (optional)."
+ "please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ InputParam(
+ kwargs_type="controlnet_kwargs",
+ description=(
+ "additional kwargs for controlnet (e.g. control_type_idx and control_type from the controlnet union input step )"
+ "please add `kwargs_type=controlnet_kwargs` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ ]
+
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ extra_controlnet_kwargs = self.prepare_extra_kwargs(
+ components.controlnet.forward, **block_state.controlnet_kwargs
+ )
+
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+ guider_input_fields = {
+ "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+ "time_ids": ("add_time_ids", "negative_add_time_ids"),
+ "text_embeds": ("pooled_prompt_embeds", "negative_pooled_prompt_embeds"),
+ "image_embeds": ("ip_adapter_embeds", "negative_ip_adapter_embeds"),
+ }
+
+ # cond_scale for the timestep (controlnet input)
+ if isinstance(block_state.controlnet_keep[i], list):
+ block_state.cond_scale = [
+ c * s for c, s in zip(block_state.conditioning_scale, block_state.controlnet_keep[i])
+ ]
+ else:
+ controlnet_cond_scale = block_state.conditioning_scale
+ if isinstance(controlnet_cond_scale, list):
+ controlnet_cond_scale = controlnet_cond_scale[0]
+ block_state.cond_scale = controlnet_cond_scale * block_state.controlnet_keep[i]
+
+ # default controlnet output/unet input for guess mode + conditional path
+ block_state.down_block_res_samples_zeros = None
+ block_state.mid_block_res_sample_zeros = None
+
+ # guided denoiser step
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
+ # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.unet)
+
+ # Prepare additional conditionings
+ added_cond_kwargs = {
+ "text_embeds": guider_state_batch.text_embeds,
+ "time_ids": guider_state_batch.time_ids,
+ }
+ if hasattr(guider_state_batch, "image_embeds") and guider_state_batch.image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = guider_state_batch.image_embeds
+
+ # Prepare controlnet additional conditionings
+ controlnet_added_cond_kwargs = {
+ "text_embeds": guider_state_batch.text_embeds,
+ "time_ids": guider_state_batch.time_ids,
+ }
+ # run controlnet for the guidance batch
+ if block_state.guess_mode and not components.guider.is_conditional:
+ # guider always run uncond batch first, so these tensors should be set already
+ down_block_res_samples = block_state.down_block_res_samples_zeros
+ mid_block_res_sample = block_state.mid_block_res_sample_zeros
+ else:
+ down_block_res_samples, mid_block_res_sample = components.controlnet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=guider_state_batch.prompt_embeds,
+ controlnet_cond=block_state.controlnet_cond,
+ conditioning_scale=block_state.cond_scale,
+ guess_mode=block_state.guess_mode,
+ added_cond_kwargs=controlnet_added_cond_kwargs,
+ return_dict=False,
+ **extra_controlnet_kwargs,
+ )
+
+ # assign it to block_state so it will be available for the uncond guidance batch
+ if block_state.down_block_res_samples_zeros is None:
+ block_state.down_block_res_samples_zeros = [torch.zeros_like(d) for d in down_block_res_samples]
+ if block_state.mid_block_res_sample_zeros is None:
+ block_state.mid_block_res_sample_zeros = torch.zeros_like(mid_block_res_sample)
+
+ # Predict the noise
+ # store the noise_pred in guider_state_batch so we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.unet(
+ block_state.scaled_latents,
+ t,
+ encoder_hidden_states=guider_state_batch.prompt_embeds,
+ timestep_cond=block_state.timestep_cond,
+ cross_attention_kwargs=block_state.cross_attention_kwargs,
+ added_cond_kwargs=added_cond_kwargs,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ return_dict=False,
+ )[0]
+ components.guider.cleanup_models(components.unet)
+
+ # Perform guidance
+ block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+ return components, block_state
+
+
+# loop step (3): scheduler step to update latents
+class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("eta", default=0.0),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ # YiYi TODO: move this out of here
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ block_state.extra_step_kwargs = self.prepare_extra_kwargs(
+ components.scheduler.step, generator=block_state.generator, eta=block_state.eta
+ )
+
+ # Perform scheduler step using the predicted output
+ block_state.latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ **block_state.extra_step_kwargs,
+ **block_state.scheduler_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != block_state.latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+ return components, block_state
+
+
+# loop step (3): scheduler step to update latents (with inpainting)
+class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents (for inpainting workflow only). "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("eta", default=0.0),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "mask",
+ type_hint=Optional[torch.Tensor],
+ description="The mask to use for the denoising process, for inpainting task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ InputParam(
+ "noise",
+ type_hint=Optional[torch.Tensor],
+ description="The noise added to the image latents, for inpainting task only. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "image_latents",
+ type_hint=Optional[torch.Tensor],
+ description="The image latents to use for the denoising process, for inpainting/image-to-image task only. Can be generated in vae_encode or prepare_latent step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ @staticmethod
+ def prepare_extra_kwargs(func, exclude_kwargs=[], **kwargs):
+ accepted_kwargs = set(inspect.signature(func).parameters.keys())
+ extra_kwargs = {}
+ for key, value in kwargs.items():
+ if key in accepted_kwargs and key not in exclude_kwargs:
+ extra_kwargs[key] = value
+
+ return extra_kwargs
+
+ def check_inputs(self, components, block_state):
+ if components.num_channels_unet == 4:
+ if block_state.image_latents is None:
+ raise ValueError(f"image_latents is required for this step {self.__class__.__name__}")
+ if block_state.mask is None:
+ raise ValueError(f"mask is required for this step {self.__class__.__name__}")
+ if block_state.noise is None:
+ raise ValueError(f"noise is required for this step {self.__class__.__name__}")
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int):
+ self.check_inputs(components, block_state)
+
+ # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ block_state.extra_step_kwargs = self.prepare_extra_kwargs(
+ components.scheduler.step, generator=block_state.generator, eta=block_state.eta
+ )
+
+ # Perform scheduler step using the predicted output
+ block_state.latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred,
+ t,
+ block_state.latents,
+ **block_state.extra_step_kwargs,
+ **block_state.scheduler_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != block_state.latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ block_state.latents = block_state.latents.to(block_state.latents_dtype)
+
+ # adjust latent for inpainting
+ if components.num_channels_unet == 4:
+ block_state.init_latents_proper = block_state.image_latents
+ if i < len(block_state.timesteps) - 1:
+ block_state.noise_timestep = block_state.timesteps[i + 1]
+ block_state.init_latents_proper = components.scheduler.add_noise(
+ block_state.init_latents_proper, block_state.noise, torch.tensor([block_state.noise_timestep])
+ )
+
+ block_state.latents = (
+ 1 - block_state.mask
+ ) * block_state.init_latents_proper + block_state.mask * block_state.latents
+
+ return components, block_state
+
+
+# the loop wrapper that iterates over the timesteps
+class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("scheduler", EulerDiscreteScheduler),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ]
+
+ @property
+ def loop_intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False
+ if block_state.disable_guidance:
+ components.guider.disable()
+ else:
+ components.guider.enable()
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+# composing the denoising loops
+class StableDiffusionXLDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLLoopBeforeDenoiser,
+ StableDiffusionXLLoopDenoiser,
+ StableDiffusionXLLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `StableDiffusionXLLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLLoopDenoiser`\n"
+ " - `StableDiffusionXLLoopAfterDenoiser`\n"
+ "This block supports both text2img and img2img tasks."
+ )
+
+
+# control_cond
+class StableDiffusionXLControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLLoopBeforeDenoiser,
+ StableDiffusionXLControlNetLoopDenoiser,
+ StableDiffusionXLLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents with controlnet. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `StableDiffusionXLLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLControlNetLoopDenoiser`\n"
+ " - `StableDiffusionXLLoopAfterDenoiser`\n"
+ "This block supports using controlnet for both text2img and img2img tasks."
+ )
+
+
+# mask
+class StableDiffusionXLInpaintDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLInpaintLoopBeforeDenoiser,
+ StableDiffusionXLLoopDenoiser,
+ StableDiffusionXLInpaintLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents(for inpainting task only). \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLLoopDenoiser`\n"
+ " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
+ "This block onlysupports inpainting tasks."
+ )
+
+
+# control_cond + mask
+class StableDiffusionXLInpaintControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
+ block_classes = [
+ StableDiffusionXLInpaintLoopBeforeDenoiser,
+ StableDiffusionXLControlNetLoopDenoiser,
+ StableDiffusionXLInpaintLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents(for inpainting task only) with controlnet. \n"
+ "Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `StableDiffusionXLInpaintLoopBeforeDenoiser`\n"
+ " - `StableDiffusionXLControlNetLoopDenoiser`\n"
+ " - `StableDiffusionXLInpaintLoopAfterDenoiser`\n"
+ "This block only supports using controlnet for inpainting tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
new file mode 100644
index 0000000000..bd0e962140
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py
@@ -0,0 +1,902 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTextModelWithProjection,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+)
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from ...models.lora import adjust_lora_scale_text_encoder
+from ...utils import (
+ USE_PEFT_BACKEND,
+ logging,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import StableDiffusionXLModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class StableDiffusionXLIPAdapterStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return (
+ "IP Adapter step that prepares ip adapter image embeddings.\n"
+ "Note that this step only prepares the embeddings - in order for it to work correctly, "
+ "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n"
+ "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
+ " for more details"
+ )
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
+ ComponentSpec(
+ "feature_extractor",
+ CLIPImageProcessor,
+ config=FrozenDict({"size": 224, "crop_size": 224}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("unet", UNet2DConditionModel),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "ip_adapter_image",
+ PipelineImageInput,
+ required=True,
+ description="The image(s) to be used as ip adapter",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
+ OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=torch.Tensor,
+ description="Negative IP adapter image embeddings",
+ ),
+ ]
+
+ @staticmethod
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self->components
+ def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(components.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = components.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = components.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = components.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self,
+ components,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ num_images_per_prompt,
+ prepare_unconditional_embeds,
+ ):
+ image_embeds = []
+ if prepare_unconditional_embeds:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ components, single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if prepare_unconditional_embeds:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if prepare_unconditional_embeds:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if prepare_unconditional_embeds:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+ block_state.device = components._execution_device
+
+ block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
+ components,
+ ip_adapter_image=block_state.ip_adapter_image,
+ ip_adapter_image_embeds=None,
+ device=block_state.device,
+ num_images_per_prompt=1,
+ prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
+ )
+ if block_state.prepare_unconditional_embeds:
+ block_state.negative_ip_adapter_embeds = []
+ for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
+ negative_image_embeds, image_embeds = image_embeds.chunk(2)
+ block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
+ block_state.ip_adapter_embeds[i] = image_embeds
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLTextEncoderStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the image generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", CLIPTextModel),
+ ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
+ ComponentSpec("tokenizer", CLIPTokenizer),
+ ComponentSpec("tokenizer_2", CLIPTokenizer),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 7.5}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return [ConfigSpec("force_zeros_for_empty_prompt", True)]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("prompt_2"),
+ InputParam("negative_prompt"),
+ InputParam("negative_prompt_2"),
+ InputParam("cross_attention_kwargs"),
+ InputParam("clip_skip"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="negative text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="pooled text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_pooled_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="negative pooled text embeddings used to guide the image generation",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ if block_state.prompt is not None and (
+ not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
+ ):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
+ elif block_state.prompt_2 is not None and (
+ not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)
+ ):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: str,
+ prompt_2: Optional[str] = None,
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prepare_unconditional_embeds: bool = True,
+ negative_prompt: Optional[str] = None,
+ negative_prompt_2: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in both text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prepare_unconditional_embeds (`bool`):
+ whether to use prepare unconditional embeddings or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ device = device or components._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
+ components._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if components.text_encoder is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
+ else:
+ scale_lora_layers(components.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # Define tokenizers and text encoders
+ tokenizers = (
+ [components.tokenizer, components.tokenizer_2]
+ if components.tokenizer is not None
+ else [components.tokenizer_2]
+ )
+ text_encoders = (
+ [components.text_encoder, components.text_encoder_2]
+ if components.text_encoder is not None
+ else [components.text_encoder_2]
+ )
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # textual inversion: process multi-vector tokens if necessary
+ prompt_embeds_list = []
+ prompts = [prompt, prompt_2]
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+ if isinstance(components, TextualInversionLoaderMixin):
+ prompt = components.maybe_convert_prompt(prompt, tokenizer)
+
+ text_inputs = tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ pooled_prompt_embeds = prompt_embeds[0]
+ if clip_skip is None:
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+ else:
+ # "2" because SDXL always indexes from the penultimate layer.
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+ prompt_embeds_list.append(prompt_embeds)
+
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+ # get unconditional embeddings for classifier free guidance
+ zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
+ if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+ elif prepare_unconditional_embeds and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+ # normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+ negative_prompt_2 = (
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+ )
+
+ uncond_tokens: List[str]
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = [negative_prompt, negative_prompt_2]
+
+ negative_prompt_embeds_list = []
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+ if isinstance(components, TextualInversionLoaderMixin):
+ negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = tokenizer(
+ negative_prompt,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ negative_prompt_embeds = text_encoder(
+ uncond_input.input_ids.to(device),
+ output_hidden_states=True,
+ )
+ # We are only ALWAYS interested in the pooled output of the final text encoder
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+ if components.text_encoder_2 is not None:
+ prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
+ else:
+ prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ if prepare_unconditional_embeds:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ if components.text_encoder_2 is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=components.text_encoder_2.dtype, device=device
+ )
+ else:
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+ if prepare_unconditional_embeds:
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+ bs_embed * num_images_per_prompt, -1
+ )
+
+ if components.text_encoder is not None:
+ if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder, lora_scale)
+
+ if components.text_encoder_2 is not None:
+ if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(components.text_encoder_2, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ block_state.text_encoder_lora_scale = (
+ block_state.cross_attention_kwargs.get("scale", None)
+ if block_state.cross_attention_kwargs is not None
+ else None
+ )
+ (
+ block_state.prompt_embeds,
+ block_state.negative_prompt_embeds,
+ block_state.pooled_prompt_embeds,
+ block_state.negative_pooled_prompt_embeds,
+ ) = self.encode_prompt(
+ components,
+ block_state.prompt,
+ block_state.prompt_2,
+ block_state.device,
+ 1,
+ block_state.prepare_unconditional_embeds,
+ block_state.negative_prompt,
+ block_state.negative_prompt_2,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ lora_scale=block_state.text_encoder_lora_scale,
+ clip_skip=block_state.clip_skip,
+ )
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class StableDiffusionXLVaeEncoderStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def description(self) -> str:
+ return "Vae Encoder step that encode the input image into a latent representation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("image", required=True),
+ InputParam("height"),
+ InputParam("width"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+ InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ InputParam(
+ "preprocess_kwargs",
+ type_hint=Optional[dict],
+ description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "image_latents",
+ type_hint=torch.Tensor,
+ description="The latents representing the reference image for image-to-image/inpainting generation",
+ )
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
+ block_state.device = components._execution_device
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+
+ block_state.image = components.image_processor.preprocess(
+ block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
+ )
+ block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+
+ block_state.batch_size = block_state.image.shape[0]
+
+ # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
+ if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
+ f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ block_state.image_latents = self._encode_vae_image(
+ components, image=block_state.image, generator=block_state.generator
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
+ model_name = "stable-diffusion-xl"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKL),
+ ComponentSpec(
+ "image_processor",
+ VaeImageProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec(
+ "mask_processor",
+ VaeImageProcessor,
+ config=FrozenDict(
+ {"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}
+ ),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Vae encoder step that prepares the image and mask for the inpainting process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height"),
+ InputParam("width"),
+ InputParam("image", required=True),
+ InputParam("mask_image", required=True),
+ InputParam("padding_mask_crop"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"
+ ),
+ OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
+ OutputParam(
+ "masked_image_latents",
+ type_hint=torch.Tensor,
+ description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)",
+ ),
+ OutputParam(
+ "crops_coords",
+ type_hint=Optional[Tuple[int, int]],
+ description="The crop coordinates to use for the preprocess/postprocess of the image and mask",
+ ),
+ ]
+
+ # Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
+ # YiYi TODO: update the _encode_vae_image so that we can use #Coped from
+ def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
+ latents_mean = latents_std = None
+ if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
+ latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
+ if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
+ latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
+
+ dtype = image.dtype
+ if components.vae.config.force_upcast:
+ image = image.float()
+ components.vae.to(dtype=torch.float32)
+
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
+
+ if components.vae.config.force_upcast:
+ components.vae.to(dtype)
+
+ image_latents = image_latents.to(dtype)
+ if latents_mean is not None and latents_std is not None:
+ latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
+ latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
+ image_latents = (image_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+ else:
+ image_latents = components.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ # modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
+ # do not accept do_classifier_free_guidance
+ def prepare_mask_latents(
+ self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+ if masked_image is not None and masked_image.shape[1] == 4:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = None
+
+ if masked_image is not None:
+ if masked_image_latents is None:
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
+
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+ return mask, masked_image_latents
+
+ @torch.no_grad()
+ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+ block_state.device = components._execution_device
+
+ if block_state.height is None:
+ block_state.height = components.default_height
+ if block_state.width is None:
+ block_state.width = components.default_width
+
+ if block_state.padding_mask_crop is not None:
+ block_state.crops_coords = components.mask_processor.get_crop_region(
+ block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop
+ )
+ block_state.resize_mode = "fill"
+ else:
+ block_state.crops_coords = None
+ block_state.resize_mode = "default"
+
+ block_state.image = components.image_processor.preprocess(
+ block_state.image,
+ height=block_state.height,
+ width=block_state.width,
+ crops_coords=block_state.crops_coords,
+ resize_mode=block_state.resize_mode,
+ )
+ block_state.image = block_state.image.to(dtype=torch.float32)
+
+ block_state.mask = components.mask_processor.preprocess(
+ block_state.mask_image,
+ height=block_state.height,
+ width=block_state.width,
+ resize_mode=block_state.resize_mode,
+ crops_coords=block_state.crops_coords,
+ )
+ block_state.masked_image = block_state.image * (block_state.mask < 0.5)
+
+ block_state.batch_size = block_state.image.shape[0]
+ block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
+ block_state.image_latents = self._encode_vae_image(
+ components, image=block_state.image, generator=block_state.generator
+ )
+
+ # 7. Prepare mask latent variables
+ block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
+ components,
+ block_state.mask,
+ block_state.masked_image,
+ block_state.batch_size,
+ block_state.height,
+ block_state.width,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
new file mode 100644
index 0000000000..c9033856bc
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py
@@ -0,0 +1,380 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ StableDiffusionXLControlNetInputStep,
+ StableDiffusionXLControlNetUnionInputStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ StableDiffusionXLImg2ImgPrepareLatentsStep,
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLInpaintPrepareLatentsStep,
+ StableDiffusionXLInputStep,
+ StableDiffusionXLPrepareAdditionalConditioningStep,
+ StableDiffusionXLPrepareLatentsStep,
+ StableDiffusionXLSetTimestepsStep,
+)
+from .decoders import (
+ StableDiffusionXLDecodeStep,
+ StableDiffusionXLInpaintOverlayMaskStep,
+)
+from .denoise import (
+ StableDiffusionXLControlNetDenoiseStep,
+ StableDiffusionXLDenoiseStep,
+ StableDiffusionXLInpaintControlNetDenoiseStep,
+ StableDiffusionXLInpaintDenoiseStep,
+)
+from .encoders import (
+ StableDiffusionXLInpaintVaeEncoderStep,
+ StableDiffusionXLIPAdapterStep,
+ StableDiffusionXLTextEncoderStep,
+ StableDiffusionXLVaeEncoderStep,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# auto blocks & sequential blocks & mappings
+
+
+# vae encoder (run before before_denoise)
+class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
+ block_names = ["inpaint", "img2img"]
+ block_trigger_inputs = ["mask_image", "image"]
+
+ @property
+ def description(self):
+ return (
+ "Vae encoder step that encode the image inputs into their latent representations.\n"
+ + "This is an auto pipeline block that works for both inpainting and img2img tasks.\n"
+ + " - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ + " - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
+ + " - if neither `mask_image` nor `image` is provided, step will be skipped."
+ )
+
+
+# optional ip-adapter (run before input step)
+class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLIPAdapterStep]
+ block_names = ["ip_adapter"]
+ block_trigger_inputs = ["ip_adapter_image"]
+
+ @property
+ def description(self):
+ return "Run IP Adapter step if `ip_adapter_image` is provided. This step should be placed before the 'input' step.\n"
+
+
+# before_denoise: text2img
+class StableDiffusionXLBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLSetTimestepsStep,
+ StableDiffusionXLPrepareLatentsStep,
+ StableDiffusionXLPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `StableDiffusionXLSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: img2img
+class StableDiffusionXLImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLImg2ImgPrepareLatentsStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: inpainting
+class StableDiffusionXLInpaintBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInputStep,
+ StableDiffusionXLImg2ImgSetTimestepsStep,
+ StableDiffusionXLInpaintPrepareLatentsStep,
+ StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents", "prepare_add_cond"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step for inpainting task.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `StableDiffusionXLImg2ImgSetTimestepsStep` is used to set the timesteps\n"
+ + " - `StableDiffusionXLInpaintPrepareLatentsStep` is used to prepare the latents\n"
+ + " - `StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep` is used to prepare the additional conditioning\n"
+ )
+
+
+# before_denoise: all task (text2img, img2img, inpainting)
+class StableDiffusionXLAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLInpaintBeforeDenoiseStep,
+ StableDiffusionXLImg2ImgBeforeDenoiseStep,
+ StableDiffusionXLBeforeDenoiseStep,
+ ]
+ block_names = ["inpaint", "img2img", "text2img"]
+ block_trigger_inputs = ["mask", "image_latents", None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2img, img2img and inpainting tasks as well as controlnet, controlnet_union.\n"
+ + " - `StableDiffusionXLInpaintBeforeDenoiseStep` (inpaint) is used when both `mask` and `image_latents` are provided.\n"
+ + " - `StableDiffusionXLImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
+ + " - `StableDiffusionXLBeforeDenoiseStep` (text2img) is used when both `image_latents` and `mask` are not provided.\n"
+ )
+
+
+# optional controlnet input step (after before_denoise, before denoise)
+# works for both controlnet and controlnet_union
+class StableDiffusionXLAutoControlNetInputStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLControlNetUnionInputStep, StableDiffusionXLControlNetInputStep]
+ block_names = ["controlnet_union", "controlnet"]
+ block_trigger_inputs = ["control_mode", "control_image"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet Input step that prepare the controlnet input.\n"
+ + "This is an auto pipeline block that works for both controlnet and controlnet_union.\n"
+ + " (it should be called right before the denoise step)"
+ + " - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.\n"
+ + " - `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided."
+ + " - if neither `control_mode` nor `control_image` is provided, step will be skipped."
+ )
+
+
+# denoise: controlnet (text2img, img2img, inpainting)
+class StableDiffusionXLAutoControlNetDenoiseStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintControlNetDenoiseStep, StableDiffusionXLControlNetDenoiseStep]
+ block_names = ["inpaint_controlnet_denoise", "controlnet_denoise"]
+ block_trigger_inputs = ["mask", "controlnet_cond"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents with controlnet. "
+ "This is a auto pipeline block that using controlnet for text2img, img2img and inpainting tasks."
+ "This block should not be used without a controlnet_cond input"
+ " - `StableDiffusionXLInpaintControlNetDenoiseStep` (inpaint_controlnet_denoise) is used when mask is provided."
+ " - `StableDiffusionXLControlNetDenoiseStep` (controlnet_denoise) is used when mask is not provided but controlnet_cond is provided."
+ " - If neither mask nor controlnet_cond are provided, step will be skipped."
+ )
+
+
+# denoise: all task with or without controlnet (text2img, img2img, inpainting)
+class StableDiffusionXLAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLAutoControlNetDenoiseStep,
+ StableDiffusionXLInpaintDenoiseStep,
+ StableDiffusionXLDenoiseStep,
+ ]
+ block_names = ["controlnet_denoise", "inpaint_denoise", "denoise"]
+ block_trigger_inputs = ["controlnet_cond", "mask", None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2img, img2img and inpainting tasks. And can be used with or without controlnet."
+ " - `StableDiffusionXLAutoControlNetDenoiseStep` (controlnet_denoise) is used when controlnet_cond is provided (support controlnet withtext2img, img2img and inpainting tasks)."
+ " - `StableDiffusionXLInpaintDenoiseStep` (inpaint_denoise) is used when mask is provided (support inpainting tasks)."
+ " - `StableDiffusionXLDenoiseStep` (denoise) is used when neither mask nor controlnet_cond are provided (support text2img and img2img tasks)."
+ )
+
+
+# decode: inpaint
+class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
+ block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
+ block_names = ["decode", "mask_overlay"]
+
+ @property
+ def description(self):
+ return (
+ "Inpaint decode step that decode the denoised latents into images outputs.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n"
+ + " - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
+ )
+
+
+# decode: all task (text2img, img2img, inpainting)
+class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
+ block_names = ["inpaint", "non-inpaint"]
+ block_trigger_inputs = ["padding_mask_crop", None]
+
+ @property
+ def description(self):
+ return (
+ "Decode step that decode the denoised latents into images outputs.\n"
+ + "This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n"
+ + " - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n"
+ + " - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."
+ )
+
+
+# ip-adapter, controlnet, text2img, img2img, inpainting
+class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLTextEncoderStep,
+ StableDiffusionXLAutoIPAdapterStep,
+ StableDiffusionXLAutoVaeEncoderStep,
+ StableDiffusionXLAutoBeforeDenoiseStep,
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoDenoiseStep,
+ StableDiffusionXLAutoDecodeStep,
+ ]
+ block_names = [
+ "text_encoder",
+ "ip_adapter",
+ "image_encoder",
+ "before_denoise",
+ "controlnet_input",
+ "denoise",
+ "decoder",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+ + "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
+ + "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ + "- to run the controlnet workflow, you need to provide `control_image`\n"
+ + "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ + "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ + "- for text-to-image generation, all you need to provide is `prompt`"
+ )
+
+
+# controlnet (input + denoise step)
+class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
+ block_classes = [
+ StableDiffusionXLAutoControlNetInputStep,
+ StableDiffusionXLAutoControlNetDenoiseStep,
+ ]
+ block_names = ["controlnet_input", "controlnet_denoise"]
+
+ @property
+ def description(self):
+ return (
+ "Controlnet auto step that prepare the controlnet input and denoise the latents. "
+ + "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
+ + " (it should be replace at 'denoise' step)"
+ )
+
+
+TEXT2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep),
+ ]
+)
+
+IMAGE2IMAGE_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("image_encoder", StableDiffusionXLVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLDenoiseStep),
+ ("decode", StableDiffusionXLDecodeStep),
+ ]
+)
+
+INPAINT_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
+ ("input", StableDiffusionXLInputStep),
+ ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
+ ("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
+ ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
+ ("denoise", StableDiffusionXLInpaintDenoiseStep),
+ ("decode", StableDiffusionXLInpaintDecodeStep),
+ ]
+)
+
+CONTROLNET_BLOCKS = InsertableDict(
+ [
+ ("denoise", StableDiffusionXLAutoControlnetStep),
+ ]
+)
+
+
+IP_ADAPTER_BLOCKS = InsertableDict(
+ [
+ ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
+ ]
+)
+
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", StableDiffusionXLTextEncoderStep),
+ ("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
+ ("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
+ ("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
+ ("controlnet_input", StableDiffusionXLAutoControlNetInputStep),
+ ("denoise", StableDiffusionXLAutoDenoiseStep),
+ ("decode", StableDiffusionXLAutoDecodeStep),
+ ]
+)
+
+
+ALL_BLOCKS = {
+ "text2img": TEXT2IMAGE_BLOCKS,
+ "img2img": IMAGE2IMAGE_BLOCKS,
+ "inpaint": INPAINT_BLOCKS,
+ "controlnet": CONTROLNET_BLOCKS,
+ "ip_adapter": IP_ADAPTER_BLOCKS,
+ "auto": AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
new file mode 100644
index 0000000000..fc030fae56
--- /dev/null
+++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py
@@ -0,0 +1,376 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...image_processor import PipelineImageInput
+from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+from ..modular_pipeline_utils import InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
+# YiYi Notes: model specific components:
+## (1) it should inherit from ModularPipeline
+## (2) acts like a container that holds components and configs
+## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
+## (4) inherit from model-specic loader class (e.g. StableDiffusionXLLoraLoaderMixin)
+## (5) how to use together with Components_manager?
+class StableDiffusionXLModularPipeline(
+ ModularPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ StableDiffusionXLLoraLoaderMixin,
+ ModularIPAdapterMixin,
+):
+ """
+ A ModularPipeline for Stable Diffusion XL.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ @property
+ def default_height(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_width(self):
+ return self.default_sample_size * self.vae_scale_factor
+
+ @property
+ def default_sample_size(self):
+ default_sample_size = 128
+ if hasattr(self, "unet") and self.unet is not None:
+ default_sample_size = self.unet.config.sample_size
+ return default_sample_size
+
+ @property
+ def vae_scale_factor(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ return vae_scale_factor
+
+ @property
+ def num_channels_unet(self):
+ num_channels_unet = 4
+ if hasattr(self, "unet") and self.unet is not None:
+ num_channels_unet = self.unet.config.in_channels
+ return num_channels_unet
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 4
+ if hasattr(self, "vae") and self.vae is not None:
+ num_channels_latents = self.vae.config.latent_channels
+ return num_channels_latents
+
+
+# YiYi/Sayak TODO: not used yet, maintain a list of schema that can be used across all pipeline blocks
+# auto_docstring
+SDXL_INPUTS_SCHEMA = {
+ "prompt": InputParam(
+ "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
+ ),
+ "prompt_2": InputParam(
+ "prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
+ ),
+ "negative_prompt": InputParam(
+ "negative_prompt",
+ type_hint=Union[str, List[str]],
+ description="The prompt or prompts not to guide the image generation",
+ ),
+ "negative_prompt_2": InputParam(
+ "negative_prompt_2",
+ type_hint=Union[str, List[str]],
+ description="The negative prompt or prompts for text_encoder_2",
+ ),
+ "cross_attention_kwargs": InputParam(
+ "cross_attention_kwargs",
+ type_hint=Optional[dict],
+ description="Kwargs dictionary passed to the AttentionProcessor",
+ ),
+ "clip_skip": InputParam(
+ "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
+ ),
+ "image": InputParam(
+ "image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="The image(s) to modify for img2img or inpainting",
+ ),
+ "mask_image": InputParam(
+ "mask_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Mask image for inpainting, white pixels will be repainted",
+ ),
+ "generator": InputParam(
+ "generator",
+ type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
+ description="Generator(s) for deterministic generation",
+ ),
+ "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
+ "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
+ "num_images_per_prompt": InputParam(
+ "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
+ ),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
+ ),
+ "timesteps": InputParam(
+ "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
+ ),
+ "sigmas": InputParam(
+ "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
+ ),
+ "denoising_end": InputParam(
+ "denoising_end",
+ type_hint=Optional[float],
+ description="Fraction of denoising process to complete before termination",
+ ),
+ # YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
+ "strength": InputParam(
+ "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
+ ),
+ "denoising_start": InputParam(
+ "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
+ ),
+ "padding_mask_crop": InputParam(
+ "padding_mask_crop",
+ type_hint=Optional[Tuple[int, int]],
+ description="Size of margin in crop for image and mask",
+ ),
+ "original_size": InputParam(
+ "original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Original size of the image for SDXL's micro-conditioning",
+ ),
+ "target_size": InputParam(
+ "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
+ ),
+ "negative_original_size": InputParam(
+ "negative_original_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on image resolution",
+ ),
+ "negative_target_size": InputParam(
+ "negative_target_size",
+ type_hint=Optional[Tuple[int, int]],
+ description="Negative conditioning based on target resolution",
+ ),
+ "crops_coords_top_left": InputParam(
+ "crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Top-left coordinates for SDXL's micro-conditioning",
+ ),
+ "negative_crops_coords_top_left": InputParam(
+ "negative_crops_coords_top_left",
+ type_hint=Tuple[int, int],
+ default=(0, 0),
+ description="Negative conditioning crop coordinates",
+ ),
+ "aesthetic_score": InputParam(
+ "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
+ ),
+ "negative_aesthetic_score": InputParam(
+ "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
+ ),
+ "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
+ "output_type": InputParam(
+ "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
+ ),
+ "ip_adapter_image": InputParam(
+ "ip_adapter_image",
+ type_hint=PipelineImageInput,
+ required=True,
+ description="Image(s) to be used as IP adapter",
+ ),
+ "control_image": InputParam(
+ "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
+ ),
+ "control_guidance_start": InputParam(
+ "control_guidance_start",
+ type_hint=Union[float, List[float]],
+ default=0.0,
+ description="When ControlNet starts applying",
+ ),
+ "control_guidance_end": InputParam(
+ "control_guidance_end",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="When ControlNet stops applying",
+ ),
+ "controlnet_conditioning_scale": InputParam(
+ "controlnet_conditioning_scale",
+ type_hint=Union[float, List[float]],
+ default=1.0,
+ description="Scale factor for ControlNet outputs",
+ ),
+ "guess_mode": InputParam(
+ "guess_mode",
+ type_hint=bool,
+ default=False,
+ description="Enables ControlNet encoder to recognize input without prompts",
+ ),
+ "control_mode": InputParam(
+ "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
+ ),
+}
+
+
+SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
+ "prompt_embeds": InputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ required=True,
+ description="Text embeddings used to guide image generation",
+ ),
+ "negative_prompt_embeds": InputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": InputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": InputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
+ "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "preprocess_kwargs": InputParam(
+ "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
+ ),
+ "latents": InputParam(
+ "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
+ ),
+ "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
+ "num_inference_steps": InputParam(
+ "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
+ ),
+ "latent_timestep": InputParam(
+ "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
+ ),
+ "image_latents": InputParam(
+ "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
+ ),
+ "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
+ "masked_image_latents": InputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "add_time_ids": InputParam(
+ "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
+ ),
+ "negative_add_time_ids": InputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "ip_adapter_embeds": InputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": InputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": InputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ required=True,
+ description="Generated images",
+ ),
+}
+
+
+SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
+ "prompt_embeds": OutputParam(
+ "prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"
+ ),
+ "negative_prompt_embeds": OutputParam(
+ "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
+ ),
+ "pooled_prompt_embeds": OutputParam(
+ "pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"
+ ),
+ "negative_pooled_prompt_embeds": OutputParam(
+ "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
+ ),
+ "batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
+ "dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+ "image_latents": OutputParam(
+ "image_latents", type_hint=torch.Tensor, description="Latents representing reference image"
+ ),
+ "mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
+ "masked_image_latents": OutputParam(
+ "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
+ ),
+ "crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
+ "timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
+ "num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
+ "latent_timestep": OutputParam(
+ "latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"
+ ),
+ "add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
+ "negative_add_time_ids": OutputParam(
+ "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
+ ),
+ "timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
+ "latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
+ "noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
+ "ip_adapter_embeds": OutputParam(
+ "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
+ ),
+ "negative_ip_adapter_embeds": OutputParam(
+ "negative_ip_adapter_embeds",
+ type_hint=List[torch.Tensor],
+ description="Negative image embeddings for IP-Adapter",
+ ),
+ "images": OutputParam(
+ "images",
+ type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
+ description="Generated images",
+ ),
+}
+
+
+SDXL_OUTPUTS_SCHEMA = {
+ "images": OutputParam(
+ "images",
+ type_hint=Union[
+ Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput
+ ],
+ description="The final generated images",
+ )
+}
diff --git a/src/diffusers/modular_pipelines/wan/__init__.py b/src/diffusers/modular_pipelines/wan/__init__.py
new file mode 100644
index 0000000000..7b548e003c
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/__init__.py
@@ -0,0 +1,66 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["encoders"] = ["WanTextEncoderStep"]
+ _import_structure["modular_blocks"] = [
+ "ALL_BLOCKS",
+ "AUTO_BLOCKS",
+ "TEXT2VIDEO_BLOCKS",
+ "WanAutoBeforeDenoiseStep",
+ "WanAutoBlocks",
+ "WanAutoBlocks",
+ "WanAutoDecodeStep",
+ "WanAutoDenoiseStep",
+ ]
+ _import_structure["modular_pipeline"] = ["WanModularPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
+ else:
+ from .encoders import WanTextEncoderStep
+ from .modular_blocks import (
+ ALL_BLOCKS,
+ AUTO_BLOCKS,
+ TEXT2VIDEO_BLOCKS,
+ WanAutoBeforeDenoiseStep,
+ WanAutoBlocks,
+ WanAutoDecodeStep,
+ WanAutoDenoiseStep,
+ )
+ from .modular_pipeline import WanModularPipeline
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/modular_pipelines/wan/before_denoise.py b/src/diffusers/modular_pipelines/wan/before_denoise.py
new file mode 100644
index 0000000000..ef65b64537
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/before_denoise.py
@@ -0,0 +1,365 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Union
+
+import torch
+
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import logging
+from ...utils.torch_utils import randn_tensor
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import WanModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# TODO(yiyi, aryan): We need another step before text encoder to set the `num_inference_steps` attribute for guider so that
+# things like when to do guidance and how many conditions to be prepared can be determined. Currently, this is done by
+# always assuming you want to do guidance in the Guiders. So, negative embeddings are prepared regardless of what the
+# configuration of guider is.
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class WanInputStep(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Input processing step that:\n"
+ " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
+ " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_videos_per_prompt`\n\n"
+ "All input tensors are expected to have either batch_size=1 or match the batch_size\n"
+ "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n"
+ "have a final batch_size of batch_size * num_videos_per_prompt."
+ )
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_videos_per_prompt", default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "prompt_embeds",
+ required=True,
+ type_hint=torch.Tensor,
+ description="Pre-generated text embeddings. Can be generated from text_encoder step.",
+ ),
+ InputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ description="Pre-generated negative text embeddings. Can be generated from text_encoder step.",
+ ),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "batch_size",
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt",
+ ),
+ OutputParam(
+ "dtype",
+ type_hint=torch.dtype,
+ description="Data type of model tensor inputs (determined by `prompt_embeds`)",
+ ),
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields", # already in intermedites state but declare here again for guider_input_fields
+ description="negative text embeddings used to guide the image generation",
+ ),
+ ]
+
+ def check_inputs(self, components, block_state):
+ if block_state.prompt_embeds is not None and block_state.negative_prompt_embeds is not None:
+ if block_state.prompt_embeds.shape != block_state.negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {block_state.negative_prompt_embeds.shape}."
+ )
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ self.check_inputs(components, block_state)
+
+ block_state.batch_size = block_state.prompt_embeds.shape[0]
+ block_state.dtype = block_state.prompt_embeds.dtype
+
+ _, seq_len, _ = block_state.prompt_embeds.shape
+ block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_videos_per_prompt, 1)
+ block_state.prompt_embeds = block_state.prompt_embeds.view(
+ block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+ )
+
+ if block_state.negative_prompt_embeds is not None:
+ _, seq_len, _ = block_state.negative_prompt_embeds.shape
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
+ 1, block_state.num_videos_per_prompt, 1
+ )
+ block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
+ block_state.batch_size * block_state.num_videos_per_prompt, seq_len, -1
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class WanSetTimestepsStep(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that sets the scheduler's timesteps for inference"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("num_inference_steps", default=50),
+ InputParam("timesteps"),
+ InputParam("sigmas"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
+ OutputParam(
+ "num_inference_steps",
+ type_hint=int,
+ description="The number of denoising steps to perform at inference time",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ block_state.device = components._execution_device
+
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
+ components.scheduler,
+ block_state.num_inference_steps,
+ block_state.device,
+ block_state.timesteps,
+ block_state.sigmas,
+ )
+
+ self.set_block_state(state, block_state)
+ return components, state
+
+
+class WanPrepareLatentsStep(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return []
+
+ @property
+ def description(self) -> str:
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("height", type_hint=int),
+ InputParam("width", type_hint=int),
+ InputParam("num_frames", type_hint=int),
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
+ InputParam("num_videos_per_prompt", type_hint=int, default=1),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam("generator"),
+ InputParam(
+ "batch_size",
+ required=True,
+ type_hint=int,
+ description="Number of prompts, the final batch size of model inputs should be `batch_size * num_videos_per_prompt`. Can be generated in input step.",
+ ),
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
+ )
+ ]
+
+ @staticmethod
+ def check_inputs(components, block_state):
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
+ ):
+ raise ValueError(
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
+ )
+ if block_state.num_frames is not None and (
+ block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
+ ):
+ raise ValueError(
+ f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
+ )
+
+ @staticmethod
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
+ def prepare_latents(
+ comp,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // comp.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // comp.vae_scale_factor_spatial,
+ int(width) // comp.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.height = block_state.height or components.default_height
+ block_state.width = block_state.width or components.default_width
+ block_state.num_frames = block_state.num_frames or components.default_num_frames
+ block_state.device = components._execution_device
+ block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
+ block_state.num_channels_latents = components.num_channels_latents
+
+ self.check_inputs(components, block_state)
+
+ block_state.latents = self.prepare_latents(
+ components,
+ block_state.batch_size * block_state.num_videos_per_prompt,
+ block_state.num_channels_latents,
+ block_state.height,
+ block_state.width,
+ block_state.num_frames,
+ block_state.dtype,
+ block_state.device,
+ block_state.generator,
+ block_state.latents,
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/decoders.py b/src/diffusers/modular_pipelines/wan/decoders.py
new file mode 100644
index 0000000000..4fadeed4b9
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/decoders.py
@@ -0,0 +1,105 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...models import AutoencoderKLWan
+from ...utils import logging
+from ...video_processor import VideoProcessor
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanDecodeStep(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("vae", AutoencoderKLWan),
+ ComponentSpec(
+ "video_processor",
+ VideoProcessor,
+ config=FrozenDict({"vae_scale_factor": 8}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def description(self) -> str:
+ return "Step that decodes the denoised latents into images"
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("output_type", default="pil"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The denoised latents from the denoising step",
+ )
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[str]:
+ return [
+ OutputParam(
+ "videos",
+ type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
+ description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
+ )
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+ vae_dtype = components.vae.dtype
+
+ if not block_state.output_type == "latent":
+ latents = block_state.latents
+ latents_mean = (
+ torch.tensor(components.vae.config.latents_mean)
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+ 1, components.vae.config.z_dim, 1, 1, 1
+ ).to(latents.device, latents.dtype)
+ latents = latents / latents_std + latents_mean
+ latents = latents.to(vae_dtype)
+ block_state.videos = components.vae.decode(latents, return_dict=False)[0]
+ else:
+ block_state.videos = block_state.latents
+
+ block_state.videos = components.video_processor.postprocess_video(
+ block_state.videos, output_type=block_state.output_type
+ )
+
+ self.set_block_state(state, block_state)
+
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/denoise.py b/src/diffusers/modular_pipelines/wan/denoise.py
new file mode 100644
index 0000000000..76c5cda5f9
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/denoise.py
@@ -0,0 +1,261 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Tuple
+
+import torch
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...models import WanTransformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import logging
+from ..modular_pipeline import (
+ BlockState,
+ LoopSequentialPipelineBlocks,
+ PipelineBlock,
+ PipelineState,
+)
+from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
+from .modular_pipeline import WanModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanLoopDenoiser(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 5.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("transformer", WanTransformer3DModel),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Step within the denoising loop that denoise the latents with guidance. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return [
+ InputParam("attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam(
+ "latents",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ kwargs_type="guider_input_fields",
+ description=(
+ "All conditional model inputs that need to be prepared with guider. "
+ "It should contain prompt_embeds/negative_prompt_embeds. "
+ "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
+ ),
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(
+ self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
+ ) -> PipelineState:
+ # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds)
+ # to the corresponding (cond, uncond) fields on block_state. (e.g. block_state.prompt_embeds, block_state.negative_prompt_embeds)
+ guider_input_fields = {
+ "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
+ }
+ transformer_dtype = components.transformer.dtype
+
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
+
+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
+ # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
+ guider_state = components.guider.prepare_inputs(block_state, guider_input_fields)
+
+ # run the denoiser for each guidance batch
+ for guider_state_batch in guider_state:
+ components.guider.prepare_models(components.transformer)
+ cond_kwargs = guider_state_batch.as_dict()
+ cond_kwargs = {k: v for k, v in cond_kwargs.items() if k in guider_input_fields}
+ prompt_embeds = cond_kwargs.pop("prompt_embeds")
+
+ # Predict the noise residual
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
+ guider_state_batch.noise_pred = components.transformer(
+ hidden_states=block_state.latents.to(transformer_dtype),
+ timestep=t.flatten(),
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=block_state.attention_kwargs,
+ return_dict=False,
+ )[0]
+ components.guider.cleanup_models(components.transformer)
+
+ # Perform guidance
+ block_state.noise_pred, block_state.scheduler_step_kwargs = components.guider(guider_state)
+
+ return components, block_state
+
+
+class WanLoopAfterDenoiser(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ]
+
+ @property
+ def description(self) -> str:
+ return (
+ "step within the denoising loop that update the latents. "
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
+ "object (e.g. `WanDenoiseLoopWrapper`)"
+ )
+
+ @property
+ def inputs(self) -> List[Tuple[str, Any]]:
+ return []
+
+ @property
+ def intermediate_inputs(self) -> List[str]:
+ return [
+ InputParam("generator"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
+ # Perform scheduler step using the predicted output
+ latents_dtype = block_state.latents.dtype
+ block_state.latents = components.scheduler.step(
+ block_state.noise_pred.float(),
+ t,
+ block_state.latents.float(),
+ **block_state.scheduler_step_kwargs,
+ return_dict=False,
+ )[0]
+
+ if block_state.latents.dtype != latents_dtype:
+ block_state.latents = block_state.latents.to(latents_dtype)
+
+ return components, block_state
+
+
+class WanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return (
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
+ )
+
+ @property
+ def loop_expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 5.0}),
+ default_creation_method="from_config",
+ ),
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
+ ComponentSpec("transformer", WanTransformer3DModel),
+ ]
+
+ @property
+ def loop_intermediate_inputs(self) -> List[InputParam]:
+ return [
+ InputParam(
+ "timesteps",
+ required=True,
+ type_hint=torch.Tensor,
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ InputParam(
+ "num_inference_steps",
+ required=True,
+ type_hint=int,
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
+ ),
+ ]
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ block_state = self.get_block_state(state)
+
+ block_state.num_warmup_steps = max(
+ len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
+ )
+
+ with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
+ for i, t in enumerate(block_state.timesteps):
+ components, block_state = self.loop_step(components, block_state, i=i, t=t)
+ if i == len(block_state.timesteps) - 1 or (
+ (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ self.set_block_state(state, block_state)
+
+ return components, state
+
+
+class WanDenoiseStep(WanDenoiseLoopWrapper):
+ block_classes = [
+ WanLoopDenoiser,
+ WanLoopAfterDenoiser,
+ ]
+ block_names = ["before_denoiser", "denoiser", "after_denoiser"]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. \n"
+ "Its loop logic is defined in `WanDenoiseLoopWrapper.__call__` method \n"
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
+ " - `WanLoopDenoiser`\n"
+ " - `WanLoopAfterDenoiser`\n"
+ "This block supports both text2vid tasks."
+ )
diff --git a/src/diffusers/modular_pipelines/wan/encoders.py b/src/diffusers/modular_pipelines/wan/encoders.py
new file mode 100644
index 0000000000..b2ecfd1aa6
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/encoders.py
@@ -0,0 +1,242 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import List, Optional, Union
+
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...configuration_utils import FrozenDict
+from ...guiders import ClassifierFreeGuidance
+from ...utils import is_ftfy_available, logging
+from ..modular_pipeline import PipelineBlock, PipelineState
+from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
+from .modular_pipeline import WanModularPipeline
+
+
+if is_ftfy_available():
+ import ftfy
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class WanTextEncoderStep(PipelineBlock):
+ model_name = "wan"
+
+ @property
+ def description(self) -> str:
+ return "Text Encoder step that generate text_embeddings to guide the video generation"
+
+ @property
+ def expected_components(self) -> List[ComponentSpec]:
+ return [
+ ComponentSpec("text_encoder", UMT5EncoderModel),
+ ComponentSpec("tokenizer", AutoTokenizer),
+ ComponentSpec(
+ "guider",
+ ClassifierFreeGuidance,
+ config=FrozenDict({"guidance_scale": 5.0}),
+ default_creation_method="from_config",
+ ),
+ ]
+
+ @property
+ def expected_configs(self) -> List[ConfigSpec]:
+ return []
+
+ @property
+ def inputs(self) -> List[InputParam]:
+ return [
+ InputParam("prompt"),
+ InputParam("negative_prompt"),
+ InputParam("attention_kwargs"),
+ ]
+
+ @property
+ def intermediate_outputs(self) -> List[OutputParam]:
+ return [
+ OutputParam(
+ "prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="text embeddings used to guide the image generation",
+ ),
+ OutputParam(
+ "negative_prompt_embeds",
+ type_hint=torch.Tensor,
+ kwargs_type="guider_input_fields",
+ description="negative text embeddings used to guide the image generation",
+ ),
+ ]
+
+ @staticmethod
+ def check_inputs(block_state):
+ if block_state.prompt is not None and (
+ not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
+ ):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
+
+ @staticmethod
+ def _get_t5_prompt_embeds(
+ components,
+ prompt: Union[str, List[str]],
+ max_sequence_length: int,
+ device: torch.device,
+ ):
+ dtype = components.text_encoder.dtype
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+
+ text_inputs = components.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+ prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ return prompt_embeds
+
+ @staticmethod
+ def encode_prompt(
+ components,
+ prompt: str,
+ device: Optional[torch.device] = None,
+ num_videos_per_prompt: int = 1,
+ prepare_unconditional_embeds: bool = True,
+ negative_prompt: Optional[str] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_videos_per_prompt (`int`):
+ number of videos that should be generated per prompt
+ prepare_unconditional_embeds (`bool`):
+ whether to use prepare unconditional embeddings or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ max_sequence_length (`int`, defaults to `512`):
+ The maximum number of text tokens to be used for the generation process.
+ """
+ device = device or components._execution_device
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
+
+ if prepare_unconditional_embeds and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
+ components, negative_prompt, max_sequence_length, device
+ )
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
+
+ if prepare_unconditional_embeds:
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ @torch.no_grad()
+ def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+ # Get inputs and intermediates
+ block_state = self.get_block_state(state)
+ self.check_inputs(block_state)
+
+ block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+ block_state.device = components._execution_device
+
+ # Encode input prompt
+ (
+ block_state.prompt_embeds,
+ block_state.negative_prompt_embeds,
+ ) = self.encode_prompt(
+ components,
+ block_state.prompt,
+ block_state.device,
+ 1,
+ block_state.prepare_unconditional_embeds,
+ block_state.negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ )
+
+ # Add outputs
+ self.set_block_state(state, block_state)
+ return components, state
diff --git a/src/diffusers/modular_pipelines/wan/modular_blocks.py b/src/diffusers/modular_pipelines/wan/modular_blocks.py
new file mode 100644
index 0000000000..5f4c1a9835
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/modular_blocks.py
@@ -0,0 +1,144 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...utils import logging
+from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
+from ..modular_pipeline_utils import InsertableDict
+from .before_denoise import (
+ WanInputStep,
+ WanPrepareLatentsStep,
+ WanSetTimestepsStep,
+)
+from .decoders import WanDecodeStep
+from .denoise import WanDenoiseStep
+from .encoders import WanTextEncoderStep
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+# before_denoise: text2vid
+class WanBeforeDenoiseStep(SequentialPipelineBlocks):
+ block_classes = [
+ WanInputStep,
+ WanSetTimestepsStep,
+ WanPrepareLatentsStep,
+ ]
+ block_names = ["input", "set_timesteps", "prepare_latents"]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is a sequential pipeline blocks:\n"
+ + " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+ + " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ )
+
+
+# before_denoise: all task (text2vid,)
+class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ WanBeforeDenoiseStep,
+ ]
+ block_names = ["text2vid"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self):
+ return (
+ "Before denoise step that prepare the inputs for the denoise step.\n"
+ + "This is an auto pipeline block that works for text2vid.\n"
+ + " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+ )
+
+
+# denoise: text2vid
+class WanAutoDenoiseStep(AutoPipelineBlocks):
+ block_classes = [
+ WanDenoiseStep,
+ ]
+ block_names = ["denoise"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self) -> str:
+ return (
+ "Denoise step that iteratively denoise the latents. "
+ "This is a auto pipeline block that works for text2vid tasks.."
+ " - `WanDenoiseStep` (denoise) for text2vid tasks."
+ )
+
+
+# decode: all task (text2img, img2img, inpainting)
+class WanAutoDecodeStep(AutoPipelineBlocks):
+ block_classes = [WanDecodeStep]
+ block_names = ["non-inpaint"]
+ block_trigger_inputs = [None]
+
+ @property
+ def description(self):
+ return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
+
+
+# text2vid
+class WanAutoBlocks(SequentialPipelineBlocks):
+ block_classes = [
+ WanTextEncoderStep,
+ WanAutoBeforeDenoiseStep,
+ WanAutoDenoiseStep,
+ WanAutoDecodeStep,
+ ]
+ block_names = [
+ "text_encoder",
+ "before_denoise",
+ "denoise",
+ "decoder",
+ ]
+
+ @property
+ def description(self):
+ return (
+ "Auto Modular pipeline for text-to-video using Wan.\n"
+ + "- for text-to-video generation, all you need to provide is `prompt`"
+ )
+
+
+TEXT2VIDEO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("input", WanInputStep),
+ ("set_timesteps", WanSetTimestepsStep),
+ ("prepare_latents", WanPrepareLatentsStep),
+ ("denoise", WanDenoiseStep),
+ ("decode", WanDecodeStep),
+ ]
+)
+
+
+AUTO_BLOCKS = InsertableDict(
+ [
+ ("text_encoder", WanTextEncoderStep),
+ ("before_denoise", WanAutoBeforeDenoiseStep),
+ ("denoise", WanAutoDenoiseStep),
+ ("decode", WanAutoDecodeStep),
+ ]
+)
+
+
+ALL_BLOCKS = {
+ "text2video": TEXT2VIDEO_BLOCKS,
+ "auto": AUTO_BLOCKS,
+}
diff --git a/src/diffusers/modular_pipelines/wan/modular_pipeline.py b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
new file mode 100644
index 0000000000..4d86e0d08e
--- /dev/null
+++ b/src/diffusers/modular_pipelines/wan/modular_pipeline.py
@@ -0,0 +1,90 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ...loaders import WanLoraLoaderMixin
+from ...pipelines.pipeline_utils import StableDiffusionMixin
+from ...utils import logging
+from ..modular_pipeline import ModularPipeline
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class WanModularPipeline(
+ ModularPipeline,
+ StableDiffusionMixin,
+ WanLoraLoaderMixin,
+):
+ """
+ A ModularPipeline for Wan.
+
+
+
+ This is an experimental feature and is likely to change in the future.
+
+
+ """
+
+ @property
+ def default_height(self):
+ return self.default_sample_height * self.vae_scale_factor_spatial
+
+ @property
+ def default_width(self):
+ return self.default_sample_width * self.vae_scale_factor_spatial
+
+ @property
+ def default_num_frames(self):
+ return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
+
+ @property
+ def default_sample_height(self):
+ return 60
+
+ @property
+ def default_sample_width(self):
+ return 104
+
+ @property
+ def default_sample_num_frames(self):
+ return 21
+
+ @property
+ def vae_scale_factor_spatial(self):
+ vae_scale_factor = 8
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def vae_scale_factor_temporal(self):
+ vae_scale_factor = 4
+ if hasattr(self, "vae") and self.vae is not None:
+ vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
+ return vae_scale_factor
+
+ @property
+ def num_channels_transformer(self):
+ num_channels_transformer = 16
+ if hasattr(self, "transformer") and self.transformer is not None:
+ num_channels_transformer = self.transformer.config.in_channels
+ return num_channels_transformer
+
+ @property
+ def num_channels_latents(self):
+ num_channels_latents = 16
+ if hasattr(self, "vae") and self.vae is not None:
+ num_channels_latents = self.vae.config.z_dim
+ return num_channels_latents
diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md
index b0a8a54b14..363caffe20 100644
--- a/src/diffusers/pipelines/README.md
+++ b/src/diffusers/pipelines/README.md
@@ -86,7 +86,7 @@ logic including pre-processing, an unrolled diffusion loop, and post-processing
### Text-to-Image generation with Stable Diffusion
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index b32d55bd51..c8fbdf0c6c 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -140,6 +140,8 @@ else:
"FluxFillPipeline",
"FluxPriorReduxPipeline",
"ReduxImageEncoder",
+ "FluxKontextPipeline",
+ "FluxKontextInpaintPipeline",
]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
@@ -378,6 +380,13 @@ else:
"WuerstchenPriorPipeline",
]
_import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"]
+ _import_structure["skyreels_v2"] = [
+ "SkyReelsV2DiffusionForcingPipeline",
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline",
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline",
+ "SkyReelsV2ImageToVideoPipeline",
+ "SkyReelsV2Pipeline",
+ ]
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -609,6 +618,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxFillPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
ReduxImageEncoder,
@@ -847,6 +858,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SpectrogramDiffusionPipeline,
)
+ from .skyreels_v2 import (
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Pipeline,
+ )
+
else:
import sys
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index b1a7ffaaea..ebabf17995 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -49,6 +49,7 @@ from .flux import (
FluxControlPipeline,
FluxImg2ImgPipeline,
FluxInpaintPipeline,
+ FluxKontextPipeline,
FluxPipeline,
)
from .hunyuandit import HunyuanDiTPipeline
@@ -142,6 +143,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxPipeline),
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
+ ("flux-kontext", FluxKontextPipeline),
("lumina", LuminaPipeline),
("lumina2", Lumina2Pipeline),
("chroma", ChromaPipeline),
@@ -171,6 +173,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxImg2ImgPipeline),
("flux-controlnet", FluxControlNetImg2ImgPipeline),
("flux-control", FluxControlImg2ImgPipeline),
+ ("flux-kontext", FluxKontextPipeline),
]
)
@@ -248,14 +251,15 @@ def _get_connected_pipeline(pipeline_cls):
return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)
-def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
- def get_model(pipeline_class_name):
- for task_mapping in SUPPORTED_TASKS_MAPPINGS:
- for model_name, pipeline in task_mapping.items():
- if pipeline.__name__ == pipeline_class_name:
- return model_name
+def _get_model(pipeline_class_name):
+ for task_mapping in SUPPORTED_TASKS_MAPPINGS:
+ for model_name, pipeline in task_mapping.items():
+ if pipeline.__name__ == pipeline_class_name:
+ return model_name
- model_name = get_model(pipeline_class_name)
+
+def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
+ model_name = _get_model(pipeline_class_name)
if model_name is not None:
task_class = mapping.get(model_name, None)
@@ -391,8 +395,8 @@ class AutoPipelineForText2Image(ConfigMixin):
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+ auth login`.
@@ -686,8 +690,8 @@ class AutoPipelineForImage2Image(ConfigMixin):
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+ auth login`.
@@ -996,8 +1000,8 @@ class AutoPipelineForInpainting(ConfigMixin):
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+ auth login`.
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index c74834ee82..3a34ec2a42 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -663,11 +663,11 @@ class ChromaPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 9936608aaf..e169db4a4d 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -725,11 +725,11 @@ class ChromaImg2ImgPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
strength (`float, *optional*, defaults to 0.9):
Conceptually, indicates how much to transform the reference image. Must be between 0 and 1. image will
be used as a starting point, adding more noise to it the larger the strength. The number of denoising
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
index f08a3c35c2..3c5994172c 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -718,14 +718,15 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
index fe3e8ae388..cf6ccebc47 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py
@@ -784,14 +784,15 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
index a982f4b275..d1f02ca9c9 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py
@@ -831,15 +831,16 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- ofs=ofs_emb,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ ofs=ofs_emb,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
index 7c50bdcb7d..230c8ca296 100644
--- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
+++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py
@@ -799,14 +799,15 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
timestep = t.expand(latent_model_input.shape[0])
# predict noise model_output
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- image_rotary_emb=image_rotary_emb,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ image_rotary_emb=image_rotary_emb,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
# perform guidance
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
index 880253459e..d8374b694f 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4.py
@@ -619,22 +619,10 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
- noise_pred_cond = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- original_size=original_size,
- target_size=target_size,
- crop_coords=crops_coords_top_left,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- # perform guidance
- if self.do_classifier_free_guidance:
- noise_pred_uncond = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
timestep=timestep,
original_size=original_size,
target_size=target_size,
@@ -643,6 +631,19 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
return_dict=False,
)[0]
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ with self.transformer.cache_context("uncond"):
+ noise_pred_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=negative_prompt_embeds,
+ timestep=timestep,
+ original_size=original_size,
+ target_size=target_size,
+ crop_coords=crops_coords_top_left,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 7d6a29ceca..598e3b5b6d 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -29,7 +29,7 @@ from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from ..pipeline_utils import DeprecatedPipelineMixin, DiffusionPipeline, ImagePipelineOutput
if is_torch_xla_available():
@@ -88,7 +88,7 @@ EXAMPLE_DOC_STRING = """
"""
-class BlipDiffusionControlNetPipeline(DiffusionPipeline):
+class BlipDiffusionControlNetPipeline(DeprecatedPipelineMixin, DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
@@ -116,6 +116,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
+ _last_supported_version = "0.33.1"
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
index fd490e1d5d..7fa59395a8 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py
@@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
-import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -35,7 +34,13 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+ AutoencoderKL,
+ ControlNetUnionModel,
+ ImageProjection,
+ MultiControlNetUnionModel,
+ UNet2DConditionModel,
+)
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -230,7 +235,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
- controlnet: ControlNetUnionModel,
+ controlnet: Union[
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+ ],
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +247,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
):
super().__init__()
- if not isinstance(controlnet, ControlNetUnionModel):
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
@@ -660,6 +667,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ control_mode=None,
callback_on_step_end_tensor_inputs=None,
padding_mask_crop=None,
):
@@ -747,25 +755,34 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
- # Check `image`
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
- )
- if (
- isinstance(self.controlnet, ControlNetModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- elif (
- isinstance(self.controlnet, ControlNetUnionModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
- else:
- assert False
+ # Check `image`
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif not all(isinstance(i, list) for i in image):
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for images_ in image:
+ for image_ in images_:
+ self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
@@ -778,6 +795,12 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
)
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
for start, end in zip(control_guidance_start, control_guidance_end):
if start >= end:
raise ValueError(
@@ -788,6 +811,28 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+ # Check `control_mode`
+ if isinstance(controlnet, ControlNetUnionModel):
+ if max(control_mode) >= controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+ if max(_control_mode) >= _controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
+ # Equal number of `image` and `control_mode` elements
+ if isinstance(controlnet, ControlNetUnionModel):
+ if len(image) != len(control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not all(isinstance(i, list) for i in control_mode):
+ raise ValueError(
+ "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
+ )
+
+ elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
+ raise ValueError("Expected len(control_image) == len(control_mode)")
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1117,7 +1162,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
- control_image: PipelineImageInput = None,
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1190,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
- control_mode: Optional[Union[int, List[int]]] = None,
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1222,13 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1269,6 +1321,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the ControlNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
+ conditions in control_image.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1333,22 +1401,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- # align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
- # # 0.0 Default height and width to unet
- # height = height or self.unet.config.sample_size * self.vae_scale_factor
- # width = width or self.unet.config.sample_size * self.vae_scale_factor
-
- # 0.1 align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1357,40 +1409,59 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
- if len(control_image) != len(control_mode):
- raise ValueError("Expected len(control_image) == len(control_type)")
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_image = [[item] for item in control_image]
+ control_mode = [[item] for item in control_mode]
- num_control_type = controlnet.config.num_control_type
-
- # 1. Check inputs
- control_type = [0 for _ in range(num_control_type)]
- for _image, control_idx in zip(control_image, control_mode):
- control_type[control_idx] = 1
- self.check_inputs(
- prompt,
- prompt_2,
- _image,
- mask_image,
- strength,
- num_inference_steps,
- callback_steps,
- output_type,
- negative_prompt,
- negative_prompt_2,
- prompt_embeds,
- negative_prompt_embeds,
- ip_adapter_image,
- ip_adapter_image_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- controlnet_conditioning_scale,
- control_guidance_start,
- control_guidance_end,
- callback_on_step_end_tensor_inputs,
- padding_mask_crop,
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
)
- control_type = torch.Tensor(control_type)
+ if isinstance(controlnet_conditioning_scale, float):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ mask_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ output_type,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ control_mode,
+ callback_on_step_end_tensor_inputs,
+ padding_mask_crop,
+ )
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+ ]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1483,21 +1554,55 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
init_image = init_image.to(dtype=torch.float32)
# 5.2 Prepare control images
- for idx, _ in enumerate(control_image):
- control_image[idx] = self.prepare_control_image(
- image=control_image[idx],
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- crops_coords=crops_coords,
- resize_mode=resize_mode,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
- guess_mode=guess_mode,
- )
- height, width = control_image[idx].shape[-2:]
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ images = []
+
+ for image_ in control_image_:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+ control_images.append(images)
+
+ control_image = control_images
+ height, width = control_image[0][0].shape[-2:]
# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
@@ -1559,10 +1664,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
# 8.2 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
- controlnet_keep.append(
- 1.0
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
- )
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps)
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
height, width = latents.shape[-2:]
@@ -1627,11 +1733,24 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
timesteps = timesteps[:num_inference_steps]
- control_type = (
- control_type.reshape(1, -1)
- .to(device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ )
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ _control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ for _control_type in control_type
+ ]
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
index 60768e43fa..5961d389ef 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py
@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline(
is_controlnet_compiled = is_compiled_module(self.controlnet)
is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
+ )
+
if isinstance(controlnet, ControlNetUnionModel):
control_type = (
control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ .repeat(control_type_repeat_factor, 1)
)
- if isinstance(controlnet, MultiControlNetUnionModel):
+ elif isinstance(controlnet, MultiControlNetUnionModel):
control_type = [
_control_type.reshape(1, -1)
.to(self._execution_device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+ .repeat(control_type_repeat_factor, 1)
for _control_type in control_type
]
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
index 82ef4b6391..65e2fe6617 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL.Image
import torch
-import torch.nn.functional as F
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
@@ -38,7 +37,13 @@ from ...loaders import (
StableDiffusionXLLoraLoaderMixin,
TextualInversionLoaderMixin,
)
-from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
+from ...models import (
+ AutoencoderKL,
+ ControlNetUnionModel,
+ ImageProjection,
+ MultiControlNetUnionModel,
+ UNet2DConditionModel,
+)
from ...models.attention_processor import (
AttnProcessor2_0,
XFormersAttnProcessor,
@@ -262,7 +267,9 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
tokenizer: CLIPTokenizer,
tokenizer_2: CLIPTokenizer,
unet: UNet2DConditionModel,
- controlnet: ControlNetUnionModel,
+ controlnet: Union[
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
+ ],
scheduler: KarrasDiffusionSchedulers,
requires_aesthetics_score: bool = False,
force_zeros_for_empty_prompt: bool = True,
@@ -272,8 +279,8 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
):
super().__init__()
- if not isinstance(controlnet, ControlNetUnionModel):
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
+ if isinstance(controlnet, (list, tuple)):
+ controlnet = MultiControlNetUnionModel(controlnet)
self.register_modules(
vae=vae,
@@ -649,6 +656,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet_conditioning_scale=1.0,
control_guidance_start=0.0,
control_guidance_end=1.0,
+ control_mode=None,
callback_on_step_end_tensor_inputs=None,
):
if strength < 0 or strength > 1:
@@ -722,28 +730,44 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ # `prompt` needs more sophisticated handling when there are multiple
+ # conditionings.
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
+ if isinstance(prompt, list):
+ logger.warning(
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
+ " prompts. The conditionings will be fixed across the prompts."
+ )
+
# Check `image`
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
- )
- if (
- isinstance(self.controlnet, ControlNetModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- elif (
- isinstance(self.controlnet, ControlNetUnionModel)
- or is_compiled
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
- ):
- self.check_image(image, prompt, prompt_embeds)
- else:
- assert False
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ for image_ in image:
+ self.check_image(image_, prompt, prompt_embeds)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ if not isinstance(image, list):
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
+ elif not all(isinstance(i, list) for i in image):
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
+ elif len(image) != len(self.controlnet.nets):
+ raise ValueError(
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
+ )
+
+ for images_ in image:
+ for image_ in images_:
+ self.check_image(image_, prompt, prompt_embeds)
if not isinstance(control_guidance_start, (tuple, list)):
control_guidance_start = [control_guidance_start]
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ if len(control_guidance_start) != len(self.controlnet.nets):
+ raise ValueError(
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
+ )
+
if not isinstance(control_guidance_end, (tuple, list)):
control_guidance_end = [control_guidance_end]
@@ -762,6 +786,15 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+ # Check `control_mode`
+ if isinstance(controlnet, ControlNetUnionModel):
+ if max(control_mode) >= controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
+ if max(_control_mode) >= _controlnet.config.num_control_type:
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
+
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
"Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1049,7 +1082,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
- control_image: PipelineImageInput = None,
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,
@@ -1074,7 +1107,7 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
- control_mode: Optional[Union[int, List[int]]] = None,
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Tuple[int, int] = None,
@@ -1104,13 +1137,13 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The initial image will be used as the starting point for the image generation process. Can also accept
image latents as `image`, if passing latents directly, it will not be encoded again.
- control_image (`PipelineImageInput`):
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
- the type is specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also
- be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
- and/or width are passed, `image` is resized according to them. If multiple ControlNets are specified in
- init, images must be passed as a list such that each element of the list can be correctly batched for
- input to a single controlnet.
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
+ images must be passed as a list such that each element of the list can be correctly batched for input
+ to a single ControlNet.
height (`int`, *optional*, defaults to the size of control_image):
The height in pixels of the generated image. Anything below 512 pixels won't work well for
[stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
@@ -1184,16 +1217,21 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
- to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
- corresponding scale as a list.
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
+ the corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
- The percentage of total steps at which the controlnet starts applying.
+ The percentage of total steps at which the ControlNet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
- The percentage of total steps at which the controlnet stops applying.
+ The percentage of total steps at which the ControlNet stops applying.
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
+ conditions in control_image
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1273,12 +1311,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
- # align format for control guidance
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-
if not isinstance(control_image, list):
control_image = [control_image]
else:
@@ -1287,37 +1319,56 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
if not isinstance(control_mode, list):
control_mode = [control_mode]
- if len(control_image) != len(control_mode):
- raise ValueError("Expected len(control_image) == len(control_type)")
+ if isinstance(controlnet, MultiControlNetUnionModel):
+ control_image = [[item] for item in control_image]
+ control_mode = [[item] for item in control_mode]
- num_control_type = controlnet.config.num_control_type
-
- # 1. Check inputs
- control_type = [0 for _ in range(num_control_type)]
- for _image, control_idx in zip(control_image, control_mode):
- control_type[control_idx] = 1
- self.check_inputs(
- prompt,
- prompt_2,
- _image,
- strength,
- num_inference_steps,
- callback_steps,
- negative_prompt,
- negative_prompt_2,
- prompt_embeds,
- negative_prompt_embeds,
- pooled_prompt_embeds,
- negative_pooled_prompt_embeds,
- ip_adapter_image,
- ip_adapter_image_embeds,
- controlnet_conditioning_scale,
- control_guidance_start,
- control_guidance_end,
- callback_on_step_end_tensor_inputs,
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ control_guidance_start, control_guidance_end = (
+ mult * [control_guidance_start],
+ mult * [control_guidance_end],
)
- control_type = torch.Tensor(control_type)
+ if isinstance(controlnet_conditioning_scale, float):
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ control_image,
+ strength,
+ num_inference_steps,
+ callback_steps,
+ negative_prompt,
+ negative_prompt_2,
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ controlnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ control_mode,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
+ ]
self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
@@ -1334,7 +1385,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
device = self._execution_device
- global_pool_conditions = controlnet.config.global_pool_conditions
+ global_pool_conditions = (
+ controlnet.config.global_pool_conditions
+ if isinstance(controlnet, ControlNetUnionModel)
+ else controlnet.nets[0].config.global_pool_conditions
+ )
guess_mode = guess_mode or global_pool_conditions
# 3.1. Encode input prompt
@@ -1372,22 +1427,55 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
self.do_classifier_free_guidance,
)
- # 4. Prepare image and controlnet_conditioning_image
+ # 4.1 Prepare image
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
- for idx, _ in enumerate(control_image):
- control_image[idx] = self.prepare_control_image(
- image=control_image[idx],
- width=width,
- height=height,
- batch_size=batch_size * num_images_per_prompt,
- num_images_per_prompt=num_images_per_prompt,
- device=device,
- dtype=controlnet.dtype,
- do_classifier_free_guidance=self.do_classifier_free_guidance,
- guess_mode=guess_mode,
- )
- height, width = control_image[idx].shape[-2:]
+ # 4.2 Prepare control images
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_images = []
+
+ for image_ in control_image:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ control_images.append(image_)
+
+ control_image = control_images
+ height, width = control_image[0].shape[-2:]
+
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_images = []
+
+ for control_image_ in control_image:
+ images = []
+
+ for image_ in control_image_:
+ image_ = self.prepare_control_image(
+ image=image_,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=controlnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+
+ images.append(image_)
+ control_images.append(images)
+
+ control_image = control_images
+ height, width = control_image[0][0].shape[-2:]
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -1414,10 +1502,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
- controlnet_keep.append(
- 1.0
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
- )
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ controlnet_keep.append(keeps)
# 7.2 Prepare added time ids & embeddings
original_size = original_size or (height, width)
@@ -1460,12 +1549,25 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
prompt_embeds = prompt_embeds.to(device)
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device)
- control_type = (
- control_type.reshape(1, -1)
- .to(device, dtype=prompt_embeds.dtype)
- .repeat(batch_size * num_images_per_prompt * 2, 1)
+
+ control_type_repeat_factor = (
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
)
+ if isinstance(controlnet, ControlNetUnionModel):
+ control_type = (
+ control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ )
+ elif isinstance(controlnet, MultiControlNetUnionModel):
+ control_type = [
+ _control_type.reshape(1, -1)
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
+ .repeat(control_type_repeat_factor, 1)
+ for _control_type in control_type
+ ]
+
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
diff --git a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
index 525c5e90c5..59c79e134e 100644
--- a/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py
@@ -717,7 +717,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Sta
from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
- # make sure you're logged in with `huggingface-cli login`
+ # make sure you're logged in with `hf auth login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 14f63ea229..68ff6c9b55 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -46,7 +46,9 @@ class DiTPipeline(DiffusionPipeline):
Parameters:
transformer ([`DiTTransformer2DModel`]):
- A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents.
+ A class conditioned `DiTTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/facebook/DiT-XL-2-256/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
scheduler ([`DDIMScheduler`]):
diff --git a/src/diffusers/pipelines/flux/__init__.py b/src/diffusers/pipelines/flux/__init__.py
index 72e1b578f2..ea25c148e2 100644
--- a/src/diffusers/pipelines/flux/__init__.py
+++ b/src/diffusers/pipelines/flux/__init__.py
@@ -33,6 +33,8 @@ else:
_import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
_import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
_import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
+ _import_structure["pipeline_flux_kontext"] = ["FluxKontextPipeline"]
+ _import_structure["pipeline_flux_kontext_inpaint"] = ["FluxKontextInpaintPipeline"]
_import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -52,6 +54,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_flux_fill import FluxFillPipeline
from .pipeline_flux_img2img import FluxImg2ImgPipeline
from .pipeline_flux_inpaint import FluxInpaintPipeline
+ from .pipeline_flux_kontext import FluxKontextPipeline
+ from .pipeline_flux_kontext_inpaint import FluxKontextInpaintPipeline
from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
else:
import sys
diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py
index bdee2ead48..7211fb5693 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux.py
@@ -674,7 +674,8 @@ class FluxPipeline(
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -687,11 +688,11 @@ class FluxPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -840,6 +841,8 @@ class FluxPipeline(
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
+ sigmas = None
image_seq_len = latents.shape[1]
mu = calculate_shift(
image_seq_len,
@@ -898,6 +901,8 @@ class FluxPipeline(
)
# 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
@@ -910,32 +915,35 @@ class FluxPipeline(
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latents,
- timestep=timestep / 1000,
- guidance=guidance,
- pooled_projections=pooled_prompt_embeds,
- encoder_hidden_states=prompt_embeds,
- txt_ids=text_ids,
- img_ids=latent_image_ids,
- joint_attention_kwargs=self.joint_attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- if negative_image_embeds is not None:
- self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
- pooled_projections=negative_pooled_prompt_embeds,
- encoder_hidden_states=negative_prompt_embeds,
- txt_ids=negative_text_ids,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latents,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_image_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index b4f77cf019..5a057f94cf 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -163,9 +163,9 @@ class FluxControlPipeline(
TextualInversionLoaderMixin,
):
r"""
- The Flux pipeline for controllable text-to-image generation.
+ The Flux pipeline for controllable text-to-image generation with image conditions.
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+ Reference: https://bfl.ai/flux-1-tools
Args:
transformer ([`FluxTransformer2DModel`]):
@@ -661,11 +661,11 @@ class FluxControlPipeline(
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with prompt at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
new file mode 100644
index 0000000000..3c78aeaf36
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py
@@ -0,0 +1,1134 @@
+# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = load_image(
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yarn-art-pikachu.png"
+ ... ).convert("RGB")
+ >>> prompt = "Make Pikachu hold a sign that says 'Black Forest Labs is awesome', yarn art style, detailed, vibrant colors"
+ >>> image = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... guidance_scale=2.5,
+ ... generator=torch.Generator().manual_seed(42),
+ ... ).images[0]
+ >>> image.save("output.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux Kontext pipeline for image-to-image and text-to-image generation.
+
+ Reference: https://bfl.ai/announcements/flux-1-kontext-dev
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ return latents, image_latents, latent_ids, image_ids
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with prompt at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ height,
+ width,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 3. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+ image = self.image_processor.preprocess(image, image_height, image_width)
+
+ # 4. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, latent_ids, image_ids = self.prepare_latents(
+ image,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+ if image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ # We set the index here to remove DtoH sync, helpful especially during compilation.
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
+ self.scheduler.set_begin_index(0)
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..6dc621901c
--- /dev/null
+++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,1460 @@
+# Copyright 2025 ZenAI. All rights reserved.
+# author: @vuongminh1907
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import (
+ CLIPImageProcessor,
+ CLIPTextModel,
+ CLIPTokenizer,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+ T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...models import AutoencoderKL, FluxTransformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ USE_PEFT_BACKEND,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import FluxPipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ # Inpainting with text only
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> prompt = "Change the yellow dinosaur to green one"
+ >>> img_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
+ ... )
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
+ >>> image.save("kontext_inpainting_normal.png")
+ ```
+
+ # Inpainting with image conditioning
+ ```py
+ >>> import torch
+ >>> from diffusers import FluxKontextInpaintPipeline
+ >>> from diffusers.utils import load_image
+
+ >>> pipe = FluxKontextInpaintPipeline.from_pretrained(
+ ... "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
+ ... )
+ >>> pipe.to("cuda")
+
+ >>> prompt = "Replace this ball"
+ >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
+ >>> mask_url = (
+ ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
+ ... )
+ >>> image_reference_url = (
+ ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
+ ... )
+
+ >>> source = load_image(img_url)
+ >>> mask = load_image(mask_url)
+ >>> image_reference = load_image(image_reference_url)
+
+ >>> mask = pipe.mask_processor.blur(mask, blur_factor=12)
+ >>> image = pipe(
+ ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
+ ... ).images[0]
+ >>> image.save("kontext_inpainting_ref.png")
+ ```
+"""
+
+PREFERRED_KONTEXT_RESOLUTIONS = [
+ (672, 1568),
+ (688, 1504),
+ (720, 1456),
+ (752, 1392),
+ (800, 1328),
+ (832, 1248),
+ (880, 1184),
+ (944, 1104),
+ (1024, 1024),
+ (1104, 944),
+ (1184, 880),
+ (1248, 832),
+ (1328, 800),
+ (1392, 752),
+ (1456, 720),
+ (1504, 688),
+ (1568, 672),
+]
+
+
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.15,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class FluxKontextInpaintPipeline(
+ DiffusionPipeline,
+ FluxLoraLoaderMixin,
+ FromSingleFileMixin,
+ TextualInversionLoaderMixin,
+ FluxIPAdapterMixin,
+):
+ r"""
+ The Flux Kontext pipeline for text-to-image generation.
+
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
+
+ Args:
+ transformer ([`FluxTransformer2DModel`]):
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ text_encoder_2 ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+ tokenizer_2 (`T5TokenizerFast`):
+ Second Tokenizer of class
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+ """
+
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+ _optional_components = ["image_encoder", "feature_extractor"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+
+ def __init__(
+ self,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ text_encoder_2: T5EncoderModel,
+ tokenizer_2: T5TokenizerFast,
+ transformer: FluxTransformer2DModel,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ text_encoder_2=text_encoder_2,
+ tokenizer=tokenizer,
+ tokenizer_2=tokenizer_2,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+ self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ self.mask_processor = VaeImageProcessor(
+ vae_scale_factor=self.vae_scale_factor * 2,
+ vae_latent_channels=self.latent_channels,
+ do_normalize=False,
+ do_binarize=True,
+ do_convert_grayscale=True,
+ )
+
+ self.tokenizer_max_length = (
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+ )
+ self.default_sample_size = 128
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+ text_inputs = self.tokenizer_2(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_length=False,
+ return_overflowing_tokens=False,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because `max_sequence_length` is set to "
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+ dtype = self.text_encoder_2.dtype
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+ def _get_clip_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ batch_size = len(prompt)
+
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer_max_length,
+ truncation=True,
+ return_overflowing_tokens=False,
+ return_length=False,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
+ )
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+ # Use pooled output of CLIPTextModel
+ prompt_embeds = prompt_embeds.pooler_output
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ prompt_2: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ num_images_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ max_sequence_length: int = 512,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ used in all text-encoders
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ device = device or self._execution_device
+
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder, lora_scale)
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+ scale_lora_layers(self.text_encoder_2, lora_scale)
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+
+ if prompt_embeds is None:
+ prompt_2 = prompt_2 or prompt
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+ # We only use the pooled prompt output from the CLIPTextModel
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ )
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt_2,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ if self.text_encoder is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ if self.text_encoder_2 is not None:
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+ return prompt_embeds, pooled_prompt_embeds, text_ids
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ return image_embeds
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+ ):
+ image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_ip_adapter_image in ip_adapter_image:
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+ image_embeds.append(single_image_embeds[None, :])
+ else:
+ if not isinstance(ip_adapter_image_embeds, list):
+ ip_adapter_image_embeds = [ip_adapter_image_embeds]
+
+ if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+ )
+
+ for single_image_embeds in ip_adapter_image_embeds:
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for single_image_embeds in image_embeds:
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(num_inference_steps * strength, num_inference_steps)
+
+ t_start = int(max(num_inference_steps - init_timestep, 0))
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+ if hasattr(self.scheduler, "set_begin_index"):
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
+
+ return timesteps, num_inference_steps - t_start
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.check_inputs
+ def check_inputs(
+ self,
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type,
+ negative_prompt=None,
+ negative_prompt_2=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ pooled_prompt_embeds=None,
+ negative_pooled_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ padding_mask_crop=None,
+ max_sequence_length=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ logger.warning(
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt_2 is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+ )
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+ raise ValueError(
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+ )
+
+ if padding_mask_crop is not None:
+ if not isinstance(image, PIL.Image.Image):
+ raise ValueError(
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+ )
+ if not isinstance(mask_image, PIL.Image.Image):
+ raise ValueError(
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+ f" {type(mask_image)}."
+ )
+ if output_type != "pil":
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+ latent_image_ids = torch.zeros(height, width, 3)
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+ latent_image_ids = latent_image_ids.reshape(
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
+ )
+
+ return latent_image_ids.to(device=device, dtype=dtype)
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+ return latents
+
+ @staticmethod
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
+ def _unpack_latents(latents, height, width, vae_scale_factor):
+ batch_size, num_patches, channels = latents.shape
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (vae_scale_factor * 2))
+ width = 2 * (int(width) // (vae_scale_factor * 2))
+
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
+
+ return latents
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i], sample_mode="argmax")
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
+
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+ return image_latents
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(
+ self,
+ image: Optional[torch.Tensor],
+ timestep: int,
+ batch_size: int,
+ num_channels_latents: int,
+ height: int,
+ width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ image_reference: Optional[torch.Tensor] = None,
+ ):
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ shape = (batch_size, num_channels_latents, height, width)
+
+ # Prepare image latents
+ image_latents = image_ids = None
+ if image is not None:
+ image = image.to(device=device, dtype=dtype)
+ if image.shape[1] != self.latent_channels:
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+ else:
+ image_latents = image
+ if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_latents.shape[0]
+ image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_latents = torch.cat([image_latents], dim=0)
+
+ # Prepare image reference latents
+ image_reference_latents = image_reference_ids = None
+ if image_reference is not None:
+ image_reference = image_reference.to(device=device, dtype=dtype)
+ if image_reference.shape[1] != self.latent_channels:
+ image_reference_latents = self._encode_vae_image(image=image_reference, generator=generator)
+ else:
+ image_reference_latents = image_reference
+ if batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] == 0:
+ # expand init_latents for batch_size
+ additional_image_per_prompt = batch_size // image_reference_latents.shape[0]
+ image_reference_latents = torch.cat([image_reference_latents] * additional_image_per_prompt, dim=0)
+ elif batch_size > image_reference_latents.shape[0] and batch_size % image_reference_latents.shape[0] != 0:
+ raise ValueError(
+ f"Cannot duplicate `image_reference` of batch size {image_reference_latents.shape[0]} to {batch_size} text prompts."
+ )
+ else:
+ image_reference_latents = torch.cat([image_reference_latents], dim=0)
+
+ latent_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+ else:
+ noise = latents.to(device=device, dtype=dtype)
+ latents = noise
+
+ image_latent_height, image_latent_width = image_latents.shape[2:]
+ image_latents = self._pack_latents(
+ image_latents, batch_size, num_channels_latents, image_latent_height, image_latent_width
+ )
+ image_ids = self._prepare_latent_image_ids(
+ batch_size, image_latent_height // 2, image_latent_width // 2, device, dtype
+ )
+ # image ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_ids[..., 0] = 1
+
+ if image_reference_latents is not None:
+ image_reference_latent_height, image_reference_latent_width = image_reference_latents.shape[2:]
+ image_reference_latents = self._pack_latents(
+ image_reference_latents,
+ batch_size,
+ num_channels_latents,
+ image_reference_latent_height,
+ image_reference_latent_width,
+ )
+ image_reference_ids = self._prepare_latent_image_ids(
+ batch_size, image_reference_latent_height // 2, image_reference_latent_width // 2, device, dtype
+ )
+ # image_reference_ids are the same as latent ids with the first dimension set to 1 instead of 0
+ image_reference_ids[..., 0] = 1
+
+ noise = self._pack_latents(noise, batch_size, num_channels_latents, height, width)
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+ return latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise
+
+ # Copied from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ ):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(mask, size=(height, width))
+ mask = mask.to(device=device, dtype=dtype)
+
+ batch_size = batch_size * num_images_per_prompt
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+
+ if masked_image.shape[1] == 16:
+ masked_image_latents = masked_image
+ else:
+ masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
+
+ masked_image_latents = (
+ masked_image_latents - self.vae.config.shift_factor
+ ) * self.vae.config.scaling_factor
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ masked_image_latents = self._pack_latents(
+ masked_image_latents,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+ mask = self._pack_latents(
+ mask.repeat(1, num_channels_latents, 1, 1),
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ )
+
+ return mask, masked_image_latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def joint_attention_kwargs(self):
+ return self._joint_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: Optional[PipelineImageInput] = None,
+ image_reference: Optional[PipelineImageInput] = None,
+ mask_image: PipelineImageInput = None,
+ prompt: Union[str, List[str]] = None,
+ prompt_2: Optional[Union[str, List[str]]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
+ true_cfg_scale: float = 1.0,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ padding_mask_crop: Optional[int] = None,
+ num_inference_steps: int = 28,
+ sigmas: Optional[List[float]] = None,
+ guidance_scale: float = 3.5,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ max_area: int = 1024**2,
+ _auto_resize: bool = True,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be be inpainted (which parts of the image
+ to be masked out with `mask_image` and repainted according to `prompt` and `image_reference`). For both
+ numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list
+ or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+ list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
+ latents as `image`, but if passing latents directly it is not encoded again.
+ image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to be used as the starting point for the
+ masked area. For both numpy array and pytorch tensor, the expected value range is between `[0, 1]` If
+ it's a tensor or a list or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)` If it is
+ a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can
+ also accept image latents as `image`, but if passing latents directly it is not encoded again.
+ mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+ `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+ are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+ single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+ color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
+ H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
+ 1)`, or `(H, W)`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+ will be used instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
+ not greater than `1`).
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+ `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
+ true_cfg_scale (`float`, *optional*, defaults to 1.0):
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+ with the same aspect ration of the image and contains all masked area, and then expand that area based
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
+ the image is large and contain information irrelevant for inpainting, such as background.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 3.5):
+ Embedded guidance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifier-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
+ ip_adapter_image: (`PipelineImageInput`, *optional*):
+ Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_ip_adapter_image:
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+ input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int` defaults to 512):
+ Maximum sequence length to use with the `prompt`.
+ max_area (`int`, defaults to `1024 ** 2`):
+ The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
+ area while maintaining the aspect ratio.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+ images.
+ """
+
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+
+ original_height, original_width = height, width
+ aspect_ratio = width / height
+ width = round((max_area * aspect_ratio) ** 0.5)
+ height = round((max_area / aspect_ratio) ** 0.5)
+
+ multiple_of = self.vae_scale_factor * 2
+ width = width // multiple_of * multiple_of
+ height = height // multiple_of * multiple_of
+
+ if height != original_height or width != original_width:
+ logger.warning(
+ f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ prompt_2,
+ image,
+ mask_image,
+ strength,
+ height,
+ width,
+ output_type=output_type,
+ negative_prompt=negative_prompt,
+ negative_prompt_2=negative_prompt_2,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ padding_mask_crop=padding_mask_crop,
+ max_sequence_length=max_sequence_length,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._joint_attention_kwargs = joint_attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ # 2. Preprocess image
+ if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
+ if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+ image = torch.cat(image, dim=0)
+ img = image[0] if isinstance(image, list) else image
+ image_height, image_width = self.image_processor.get_default_height_width(img)
+ aspect_ratio = image_width / image_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_width, image_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_width = image_width // multiple_of * multiple_of
+ image_height = image_height // multiple_of * multiple_of
+ image = self.image_processor.resize(image, image_height, image_width)
+
+ # Choose the resolution of the image to be the same as the image
+ width = image_width
+ height = image_height
+
+ # 2.1 Preprocess mask
+ if padding_mask_crop is not None:
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+ resize_mode = "fill"
+ else:
+ crops_coords = None
+ resize_mode = "default"
+
+ image = self.image_processor.preprocess(
+ image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode
+ )
+ else:
+ raise ValueError("image must be provided correctly for inpainting")
+
+ init_image = image.to(dtype=torch.float32)
+
+ # 2.1 Preprocess image_reference
+ if image_reference is not None and not (
+ isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels
+ ):
+ if (
+ isinstance(image_reference, list)
+ and isinstance(image_reference[0], torch.Tensor)
+ and image_reference[0].ndim == 4
+ ):
+ image_reference = torch.cat(image_reference, dim=0)
+ img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference
+ image_reference_height, image_reference_width = self.image_processor.get_default_height_width(
+ img_reference
+ )
+ aspect_ratio = image_reference_width / image_reference_height
+ if _auto_resize:
+ # Kontext is trained on specific resolutions, using one of them is recommended
+ _, image_reference_width, image_reference_height = min(
+ (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
+ )
+ image_reference_width = image_reference_width // multiple_of * multiple_of
+ image_reference_height = image_reference_height // multiple_of * multiple_of
+ image_reference = self.image_processor.resize(
+ image_reference, image_reference_height, image_reference_width
+ )
+ image_reference = self.image_processor.preprocess(
+ image_reference,
+ image_reference_height,
+ image_reference_width,
+ crops_coords=crops_coords,
+ resize_mode=resize_mode,
+ )
+ else:
+ image_reference = None
+
+ # 3. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ lora_scale = (
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+ )
+ has_neg_prompt = negative_prompt is not None or (
+ negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
+ )
+ do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
+ (
+ prompt_embeds,
+ pooled_prompt_embeds,
+ text_ids,
+ ) = self.encode_prompt(
+ prompt=prompt,
+ prompt_2=prompt_2,
+ prompt_embeds=prompt_embeds,
+ pooled_prompt_embeds=pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+ if do_true_cfg:
+ (
+ negative_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ negative_text_ids,
+ ) = self.encode_prompt(
+ prompt=negative_prompt,
+ prompt_2=negative_prompt_2,
+ prompt_embeds=negative_prompt_embeds,
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ max_sequence_length=max_sequence_length,
+ lora_scale=lora_scale,
+ )
+
+ # 4. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels // 4
+ latents, image_latents, image_reference_latents, latent_ids, image_ids, image_reference_ids, noise = (
+ self.prepare_latents(
+ init_image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image_reference,
+ )
+ )
+
+ if image_reference_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_reference_ids], dim=0) # dim 0 is sequence dimension
+ elif image_ids is not None:
+ latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
+
+ mask_condition = self.mask_processor.preprocess(
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+ )
+
+ masked_image = init_image * (mask_condition < 0.5)
+
+ mask, _ = self.prepare_mask_latents(
+ mask_condition,
+ masked_image,
+ batch_size,
+ num_channels_latents,
+ num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ )
+
+ # handle guidance
+ if self.transformer.config.guidance_embeds:
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
+ guidance = guidance.expand(latents.shape[0])
+ else:
+ guidance = None
+
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+ ):
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+ ):
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+ ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
+
+ if self.joint_attention_kwargs is None:
+ self._joint_attention_kwargs = {}
+
+ image_embeds = None
+ negative_image_embeds = None
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+ negative_ip_adapter_image,
+ negative_ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ )
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ if image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
+
+ latent_model_input = latents
+ if image_reference_latents is not None:
+ latent_model_input = torch.cat([latents, image_reference_latents], dim=1)
+ elif image_latents is not None:
+ latent_model_input = torch.cat([latents, image_latents], dim=1)
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ txt_ids=text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred[:, : latents.size(1)]
+
+ if do_true_cfg:
+ if negative_image_embeds is not None:
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep / 1000,
+ guidance=guidance,
+ pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=negative_prompt_embeds,
+ txt_ids=negative_text_ids,
+ img_ids=latent_ids,
+ joint_attention_kwargs=self.joint_attention_kwargs,
+ return_dict=False,
+ )[0]
+ neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ init_latents_proper = image_latents
+ init_mask = mask
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.scale_noise(
+ init_latents_proper, torch.tensor([noise_timestep]), noise
+ )
+
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if output_type == "latent":
+ image = latents
+ else:
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return FluxPipelineOutput(images=image)
diff --git a/src/diffusers/pipelines/flux/pipeline_output.py b/src/diffusers/pipelines/flux/pipeline_output.py
index 388824e89f..69e742d3e0 100644
--- a/src/diffusers/pipelines/flux/pipeline_output.py
+++ b/src/diffusers/pipelines/flux/pipeline_output.py
@@ -11,12 +11,14 @@ from ...utils import BaseOutput
@dataclass
class FluxPipelineOutput(BaseOutput):
"""
- Output class for Stable Diffusion pipelines.
+ Output class for Flux image generation pipelines.
Args:
- images (`List[PIL.Image.Image]` or `np.ndarray`)
- List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
- num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+ images (`List[PIL.Image.Image]` or `torch.Tensor` or `np.ndarray`)
+ List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
+ height, width, num_channels)`. PIL images or numpy array present the denoised images of the diffusion
+ pipeline. Torch tensors can represent either the denoised images or the intermediate latents ready to be
+ passed to the decoder.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
index 341cdaf1e6..695f54f3d9 100644
--- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
+++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py
@@ -763,11 +763,11 @@ class HiDreamImagePipeline(DiffusionPipeline, HiDreamImageLoraLoaderMixin):
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
index b617e4f8b2..76b288ed0b 100644
--- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
+++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py
@@ -529,15 +529,14 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
- When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
+ True classifier-free guidance (guidance scale) is enabled when `true_cfg_scale` > 1 and
+ `negative_prompt` is provided.
guidance_scale (`float`, defaults to `6.0`):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality. Note that the only available
- HunyuanVideo model is CFG-distilled, which means that traditional guidance between unconditional and
- conditional latent is not applied.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -693,28 +692,30 @@ class HunyuanVideoPipeline(DiffusionPipeline, HunyuanVideoLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_attention_mask=prompt_attention_mask,
- pooled_projections=pooled_prompt_embeds,
- guidance=guidance,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if do_true_cfg:
- neg_noise_pred = self.transformer(
+ with self.transformer.cache_context("cond"):
+ noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
- encoder_attention_mask=negative_prompt_attention_mask,
- pooled_projections=negative_pooled_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
+ encoder_attention_mask=prompt_attention_mask,
+ pooled_projections=pooled_prompt_embeds,
guidance=guidance,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
+
+ if do_true_cfg:
+ with self.transformer.cache_context("uncond"):
+ neg_noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_attention_mask=negative_prompt_attention_mask,
+ pooled_projections=negative_pooled_prompt_embeds,
+ guidance=guidance,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx.py b/src/diffusers/pipelines/ltx/pipeline_ltx.py
index 3b58b4a45a..77ba751700 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx.py
@@ -757,18 +757,19 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
index fa9ee4fc7b..217478f418 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_condition.py
@@ -1177,15 +1177,16 @@ class LTXConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraL
if is_conditioning_image_or_video:
timestep = torch.min(timestep, (1 - conditioning_mask_model_input) * 1000.0)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- video_coords=video_coords,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ video_coords=video_coords,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
diff --git a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
index 99412b6962..8793d81377 100644
--- a/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
+++ b/src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
@@ -830,18 +830,19 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
timestep = t.expand(latent_model_input.shape[0])
timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- num_frames=latent_num_frames,
- height=latent_height,
- width=latent_width,
- rope_interpolation_scale=rope_interpolation_scale,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ num_frames=latent_num_frames,
+ height=latent_height,
+ width=latent_width,
+ rope_interpolation_scale=rope_interpolation_scale,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
noise_pred = noise_pred.float()
if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/mochi/pipeline_mochi.py b/src/diffusers/pipelines/mochi/pipeline_mochi.py
index 7712b41524..3c0f908296 100644
--- a/src/diffusers/pipelines/mochi/pipeline_mochi.py
+++ b/src/diffusers/pipelines/mochi/pipeline_mochi.py
@@ -671,14 +671,15 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- encoder_hidden_states=prompt_embeds,
- timestep=timestep,
- encoder_attention_mask=prompt_attention_mask,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ with self.transformer.cache_context("cond_uncond"):
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ encoder_hidden_states=prompt_embeds,
+ timestep=timestep,
+ encoder_attention_mask=prompt_attention_mask,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
index f6906074b3..1254b6725f 100644
--- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py
@@ -23,12 +23,14 @@ from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import OmniGenTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_torch_xla_available, is_torchvision_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
-from .processor_omnigen import OmniGenMultiModalProcessor
+if is_torchvision_available():
+ from .processor_omnigen import OmniGenMultiModalProcessor
+
if is_torch_xla_available():
XLA_AVAILABLE = True
else:
diff --git a/src/diffusers/pipelines/omnigen/processor_omnigen.py b/src/diffusers/pipelines/omnigen/processor_omnigen.py
index be5ff82c4a..7ed11871bb 100644
--- a/src/diffusers/pipelines/omnigen/processor_omnigen.py
+++ b/src/diffusers/pipelines/omnigen/processor_omnigen.py
@@ -18,7 +18,12 @@ from typing import Dict, List
import numpy as np
import torch
from PIL import Image
-from torchvision import transforms
+
+from ...utils import is_torchvision_available
+
+
+if is_torchvision_available():
+ from torchvision import transforms
def crop_image(pil_image, max_image_size):
diff --git a/src/diffusers/pipelines/pipeline_flax_utils.py b/src/diffusers/pipelines/pipeline_flax_utils.py
index 7c5ac89602..ea2c0763d9 100644
--- a/src/diffusers/pipelines/pipeline_flax_utils.py
+++ b/src/diffusers/pipelines/pipeline_flax_utils.py
@@ -278,8 +278,8 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`.
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`.
diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index d1c2c2adb4..b5ac6cc301 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -371,6 +371,22 @@ def maybe_raise_or_warn(
)
+# a simpler version of get_class_obj_and_candidates, it won't work with custom code
+def simple_get_class_obj(library_name, class_name):
+ from diffusers import pipelines
+
+ is_pipeline_module = hasattr(pipelines, library_name)
+
+ if is_pipeline_module:
+ pipeline_module = getattr(pipelines, library_name)
+ class_obj = getattr(pipeline_module, class_name)
+ else:
+ library = importlib.import_module(library_name)
+ class_obj = getattr(library, class_name)
+
+ return class_obj
+
+
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
@@ -452,7 +468,7 @@ def _get_pipeline_class(
revision=revision,
)
- if class_obj.__name__ != "DiffusionPipeline":
+ if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
@@ -892,7 +908,10 @@ def _fetch_class_library_tuple(module):
library = not_compiled_module.__module__
# retrieve class_name
- class_name = not_compiled_module.__class__.__name__
+ if isinstance(not_compiled_module, type):
+ class_name = not_compiled_module.__name__
+ else:
+ class_name = not_compiled_module.__class__.__name__
return (library, class_name)
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 2c03811b51..22efaccec1 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -710,8 +710,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with
- `huggingface-cli login`.
+ To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+ auth login`.
@@ -1096,6 +1096,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
if device_map is not None:
setattr(model, "hf_device_map", final_device_map)
+ if quantization_config is not None:
+ setattr(model, "quantization_config", quantization_config)
return model
@property
@@ -1428,8 +1430,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`.
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login
@@ -1986,11 +1988,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
}
+ optional_components = (
+ pipeline._optional_components
+ if hasattr(pipeline, "_optional_components") and pipeline._optional_components
+ else []
+ )
missing_modules = (
- set(expected_modules)
- - set(pipeline._optional_components)
- - set(pipeline_kwargs.keys())
- - set(true_optional_modules)
+ set(expected_modules) - set(optional_components) - set(pipeline_kwargs.keys()) - set(true_optional_modules)
)
if len(missing_modules) > 0:
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
index f7e70c511b..bd69746be3 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py
@@ -256,7 +256,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`PixArtTransformer2DModel`]):
- A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
+ A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
index c3d235d91b..c14036cf94 100644
--- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
+++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py
@@ -185,6 +185,26 @@ def retrieve_timesteps(
class PixArtSigmaPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Sigma.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`T5EncoderModel`]):
+ Frozen text-encoder. PixArt-Alpha uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`T5Tokenizer`):
+ Tokenizer of class
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
+ transformer ([`PixArtTransformer2DModel`]):
+ A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents. Initially published as
+ [`Transformer2DModel`](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS/blob/main/transformer/config.json#L2)
+ in the config, but the mismatch can be ignored.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
index fcf854a54c..e8f9d8368f 100644
--- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
+++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py
@@ -643,11 +643,11 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.5):
- Guidance scale as defined in [Classifier-Free Diffusion
- Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
- of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
- `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
- the text `prompt`, usually at the expense of lower image quality.
+ Embedded guiddance scale is enabled by setting `guidance_scale` > 1. Higher `guidance_scale` encourages
+ a model to generate images more aligned with `prompt` at the expense of lower image quality.
+
+ Guidance-distilled models approximates true classifer-free guidance for `guidance_scale` > 1. Refer to
+ the [paper](https://huggingface.co/papers/2210.03142) to learn more.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
diff --git a/src/diffusers/pipelines/skyreels_v2/__init__.py b/src/diffusers/pipelines/skyreels_v2/__init__.py
new file mode 100644
index 0000000000..84d2a2dd35
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/__init__.py
@@ -0,0 +1,59 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_skyreels_v2"] = ["SkyReelsV2Pipeline"]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing"] = ["SkyReelsV2DiffusionForcingPipeline"]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing_i2v"] = [
+ "SkyReelsV2DiffusionForcingImageToVideoPipeline"
+ ]
+ _import_structure["pipeline_skyreels_v2_diffusion_forcing_v2v"] = [
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline"
+ ]
+ _import_structure["pipeline_skyreels_v2_i2v"] = ["SkyReelsV2ImageToVideoPipeline"]
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_skyreels_v2 import SkyReelsV2Pipeline
+ from .pipeline_skyreels_v2_diffusion_forcing import SkyReelsV2DiffusionForcingPipeline
+ from .pipeline_skyreels_v2_diffusion_forcing_i2v import SkyReelsV2DiffusionForcingImageToVideoPipeline
+ from .pipeline_skyreels_v2_diffusion_forcing_v2v import SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ from .pipeline_skyreels_v2_i2v import SkyReelsV2ImageToVideoPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_output.py b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py
new file mode 100644
index 0000000000..7a170d24c3
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class SkyReelsV2PipelineOutput(BaseOutput):
+ r"""
+ Output class for SkyReelsV2 pipelines.
+
+ Args:
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+ `(batch_size, num_frames, channels, height, width)`.
+ """
+
+ frames: torch.Tensor
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
new file mode 100644
index 0000000000..8562a5eaf0
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
@@ -0,0 +1,610 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import regex as re
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2Pipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-T2V-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2Pipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ r"""
+ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ int(height) // self.vae_scale_factor_spatial,
+ int(width) // self.vae_scale_factor_spatial,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ height (`int`, defaults to `544`):
+ The height in pixels of the generated image.
+ width (`int`, defaults to `960`):
+ The width in pixels of the generated image.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length for the text encoder.
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = latents.to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
new file mode 100644
index 0000000000..d0a4e118ce
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
@@ -0,0 +1,978 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=30,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=5, # Controls asynchronous inference (0 for synchronous mode)
+ ... causal_block_size=5, # Number of frames processed together in a causal block
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Text-to-Video (t2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ prefix_video_latents = None
+ prefix_video_latents_frames = 0
+
+ if video_latents is not None: # long video generation at the iterations other than the first one
+ prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if prefix_video_latents.shape[2] % causal_block_size != 0:
+ truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = prefix_video_latents.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
+ + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+ elif base_latent_num_frames is not None: # long video generation at the first iteration
+ num_latent_frames = base_latent_num_frames
+ else: # short video generation
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
+
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Determine if we're doing long video generation
+ is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
+ # Initialize accumulated_latents to store all latents in one tensor
+ accumulated_latents = None
+ if is_long_video:
+ # Long video generation setup
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ else:
+ # Short video generation setup
+ n_iter = 1
+ base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ # Loop through iterations (multiple iterations only for long videos)
+ for iter_idx in range(n_iter):
+ if is_long_video:
+ logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
+ self.prepare_latents(
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if iter_idx == 0 else None,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ base_latent_num_frames=base_latent_num_frames if is_long_video else None,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
+ long_video_iter=iter_idx if is_long_video else None,
+ )
+ )
+
+ if prefix_video_latents_frames > 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)
+
+ # 6. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+
+ # Different matrix generation for short vs long video
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ current_num_latent_frames if is_long_video else base_latent_num_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ # 7. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Handle latent accumulation for long videos or use the current latents for short videos
+ if is_long_video:
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
+ )
+
+ if is_long_video:
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
new file mode 100644
index 0000000000..959cbb32f2
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
@@ -0,0 +1,1059 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import ftfy
+import PIL
+import torch
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from diffusers.image_processor import PipelineImageInput
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.video_processor import VideoProcessor
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+ >>> from PIL import Image
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingImageToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> image = Image.open("path/to/image.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ image: Optional[PipelineImageInput],
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ prefix_video_latents_frames = 0
+
+ if video_latents is not None: # long video generation at the iterations other than the first one
+ condition = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if condition.shape[2] % causal_block_size != 0:
+ truncate_len_latents = condition.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ condition = condition[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = condition.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames)
+ + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+ elif base_latent_num_frames is not None: # long video generation at the first iteration
+ num_latent_frames = base_latent_num_frames
+ else: # short video generation
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ if image is not None:
+ image = image.unsqueeze(2)
+ if last_image is not None:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat([image, last_image], dim=0)
+ else:
+ video_condition = image
+
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat_interleave(batch_size, dim=0)
+
+ latent_condition = latent_condition.to(dtype)
+ condition = (latent_condition - latents_mean) * latents_std
+ prefix_video_latents_frames = condition.shape[2]
+
+ return latents, num_latent_frames, condition, prefix_video_latents_frames
+
+ # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ last_image (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Determine if we're doing long video generation
+ is_long_video = overlap_history is not None and base_num_frames is not None and num_frames > base_num_frames
+ # Initialize accumulated_latents to store all latents in one tensor
+ accumulated_latents = None
+ if is_long_video:
+ # Long video generation setup
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ else:
+ # Short video generation setup
+ n_iter = 1
+ base_latent_num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ # Loop through iterations (multiple iterations only for long videos)
+ for iter_idx in range(n_iter):
+ if is_long_video:
+ logger.debug(f"Processing iteration {iter_idx + 1}/{n_iter} for long video generation...")
+
+ num_channels_latents = self.vae.config.z_dim
+ latents, current_num_latent_frames, condition, prefix_video_latents_frames = self.prepare_latents(
+ image if iter_idx == 0 else None,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if iter_idx == 0 else None,
+ last_image,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ base_latent_num_frames=base_latent_num_frames if is_long_video else None,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames if is_long_video else None,
+ long_video_iter=iter_idx if is_long_video else None,
+ )
+
+ if iter_idx == 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = condition[: (condition.shape[0] + 1) // 2].to(
+ transformer_dtype
+ )
+ else:
+ latents[:, :, :prefix_video_latents_frames, :, :] = condition.to(transformer_dtype)
+
+ if iter_idx == 0 and last_image is not None:
+ end_video_latents = condition[condition.shape[0] // 2 :].to(transformer_dtype)
+
+ if last_image is not None and iter_idx + 1 == n_iter:
+ latents = torch.cat([latents, end_video_latents], dim=2)
+ base_latent_num_frames += prefix_video_latents_frames
+ current_num_latent_frames += prefix_video_latents_frames
+
+ # 4. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ base_latent_num_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ if last_image is not None and iter_idx + 1 == n_iter:
+ step_matrix[:, -prefix_video_latents_frames:] = 0
+ step_update_mask[:, -prefix_video_latents_frames:] = False
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ # Handle latent accumulation for long videos or use the current latents for short videos
+ if is_long_video:
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]],
+ dim=2,
+ )
+
+ if is_long_video:
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ if last_image is not None:
+ latents = latents[:, :, :-prefix_video_latents_frames, :, :].to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
new file mode 100644
index 0000000000..6fedfc795a
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
@@ -0,0 +1,1063 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+import inspect
+import math
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import ftfy
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-DF-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-DF-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2DiffusionForcingVideoToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-DF-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+
+ >>> output = pipe(
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=6.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ar_step=0, # Controls asynchronous inference (0 for synchronous mode)
+ ... overlap_history=None, # Number of frames to overlap for smooth transitions in long videos
+ ... addnoise_condition=20, # Improves consistency in long video generation
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ """
+ Pipeline for Video-to-Video (v2v) generation using SkyReels-V2 with diffusion forcing.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a specific device, etc.).
+
+ Args:
+ tokenizer ([`AutoTokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`UMT5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the encoded image latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video=None,
+ latents=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ overlap_history=None,
+ num_frames=None,
+ base_num_frames=None,
+ ):
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ if video is not None and latents is not None:
+ raise ValueError("Only one of `video` or `latents` should be provided")
+
+ if num_frames > base_num_frames and overlap_history is None:
+ raise ValueError(
+ "`overlap_history` is required when `num_frames` exceeds `base_num_frames` to ensure smooth transitions in long video generation. "
+ "Please specify a value for `overlap_history`. Recommended values are 17 or 37."
+ )
+
+ def prepare_latents(
+ self,
+ video: torch.Tensor,
+ batch_size: int = 1,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 97,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ video_latents: Optional[torch.Tensor] = None,
+ base_latent_num_frames: Optional[int] = None,
+ overlap_history: Optional[int] = None,
+ causal_block_size: Optional[int] = None,
+ overlap_history_latent_frames: Optional[int] = None,
+ long_video_iter: Optional[int] = None,
+ ) -> torch.Tensor:
+ if latents is not None:
+ return latents.to(device=device, dtype=dtype)
+
+ num_latent_frames = (
+ (num_frames - 1) // self.vae_scale_factor_temporal + 1 if latents is None else latents.shape[2]
+ )
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ if long_video_iter == 0:
+ prefix_video_latents = [
+ retrieve_latents(
+ self.vae.encode(
+ vid.unsqueeze(0)[:, :, -overlap_history:] if vid.dim() == 4 else vid[:, :, -overlap_history:]
+ ),
+ sample_mode="argmax",
+ )
+ for vid in video
+ ]
+ prefix_video_latents = torch.cat(prefix_video_latents, dim=0).to(dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(device, self.vae.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ device, self.vae.dtype
+ )
+ prefix_video_latents = (prefix_video_latents - latents_mean) * latents_std
+ else:
+ prefix_video_latents = video_latents[:, :, -overlap_history_latent_frames:]
+
+ if prefix_video_latents.shape[2] % causal_block_size != 0:
+ truncate_len_latents = prefix_video_latents.shape[2] % causal_block_size
+ logger.warning(
+ f"The length of prefix video latents is truncated by {truncate_len_latents} frames for the causal block size alignment. "
+ f"This truncation ensures compatibility with the causal block size, which is required for proper processing. "
+ f"However, it may slightly affect the continuity of the generated video at the truncation boundary."
+ )
+ prefix_video_latents = prefix_video_latents[:, :, :-truncate_len_latents]
+ prefix_video_latents_frames = prefix_video_latents.shape[2]
+
+ finished_frame_num = (
+ long_video_iter * (base_latent_num_frames - overlap_history_latent_frames) + overlap_history_latent_frames
+ )
+ left_frame_num = num_latent_frames - finished_frame_num
+ num_latent_frames = min(left_frame_num + overlap_history_latent_frames, base_latent_num_frames)
+
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_latent_frames,
+ latent_height,
+ latent_width,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+
+ return latents, num_latent_frames, prefix_video_latents, prefix_video_latents_frames
+
+ # Copied from diffusers.pipelines.skyreels_v2.pipeline_skyreels_v2_diffusion_forcing.SkyReelsV2DiffusionForcingPipeline.generate_timestep_matrix
+ def generate_timestep_matrix(
+ self,
+ num_latent_frames: int,
+ step_template: torch.Tensor,
+ base_num_latent_frames: int,
+ ar_step: int = 5,
+ num_pre_ready: int = 0,
+ causal_block_size: int = 1,
+ shrink_interval_with_mask: bool = False,
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
+ """
+ This function implements the core diffusion forcing algorithm that creates a coordinated denoising schedule
+ across temporal frames. It supports both synchronous and asynchronous generation modes:
+
+ **Synchronous Mode** (ar_step=0, causal_block_size=1):
+ - All frames are denoised simultaneously at each timestep
+ - Each frame follows the same denoising trajectory: [1000, 800, 600, ..., 0]
+ - Simpler but may have less temporal consistency for long videos
+
+ **Asynchronous Mode** (ar_step>0, causal_block_size>1):
+ - Frames are grouped into causal blocks and processed block/chunk-wise
+ - Each block is denoised in a staggered pattern creating a "denoising wave"
+ - Earlier blocks are more denoised, later blocks lag behind by ar_step timesteps
+ - Creates stronger temporal dependencies and better consistency
+
+ Args:
+ num_latent_frames (int): Total number of latent frames to generate
+ step_template (torch.Tensor): Base timestep schedule (e.g., [1000, 800, 600, ..., 0])
+ base_num_latent_frames (int): Maximum frames the model can process in one forward pass
+ ar_step (int, optional): Autoregressive step size for temporal lag.
+ 0 = synchronous, >0 = asynchronous. Defaults to 5.
+ num_pre_ready (int, optional):
+ Number of frames already denoised (e.g., from prefix in a video2video task).
+ Defaults to 0.
+ causal_block_size (int, optional): Number of frames processed as a causal block.
+ Defaults to 1.
+ shrink_interval_with_mask (bool, optional): Whether to optimize processing intervals.
+ Defaults to False.
+
+ Returns:
+ tuple containing:
+ - step_matrix (torch.Tensor): Matrix of timesteps for each frame at each iteration Shape:
+ [num_iterations, num_latent_frames]
+ - step_index (torch.Tensor): Index matrix for timestep lookup Shape: [num_iterations,
+ num_latent_frames]
+ - step_update_mask (torch.Tensor): Boolean mask indicating which frames to update Shape:
+ [num_iterations, num_latent_frames]
+ - valid_interval (list[tuple]): List of (start, end) intervals for each iteration
+
+ Raises:
+ ValueError: If ar_step is too small for the given configuration
+ """
+ # Initialize lists to store the scheduling matrices and metadata
+ step_matrix, step_index = [], [] # Will store timestep values and indices for each iteration
+ update_mask, valid_interval = [], [] # Will store update masks and processing intervals
+
+ # Calculate total number of denoising iterations (add 1 for initial noise state)
+ num_iterations = len(step_template) + 1
+
+ # Convert frame counts to block counts for causal processing
+ # Each block contains causal_block_size frames that are processed together
+ # E.g.: 25 frames ÷ 5 = 5 blocks total
+ num_blocks = num_latent_frames // causal_block_size
+ base_num_blocks = base_num_latent_frames // causal_block_size
+
+ # Validate ar_step is sufficient for the given configuration
+ # In asynchronous mode, we need enough timesteps to create the staggered pattern
+ if base_num_blocks < num_blocks:
+ min_ar_step = len(step_template) / base_num_blocks
+ if ar_step < min_ar_step:
+ raise ValueError(f"`ar_step` should be at least {math.ceil(min_ar_step)} in your setting")
+
+ # Extend step_template with boundary values for easier indexing
+ # 999: dummy value for counter starting from 1
+ # 0: final timestep (completely denoised)
+ step_template = torch.cat(
+ [
+ torch.tensor([999], dtype=torch.int64, device=step_template.device),
+ step_template.long(),
+ torch.tensor([0], dtype=torch.int64, device=step_template.device),
+ ]
+ )
+
+ # Initialize the previous row state (tracks denoising progress for each block)
+ # 0 means not started, num_iterations means fully denoised
+ pre_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Mark pre-ready frames (e.g., from prefix video for a video2video task) as already at final denoising state
+ if num_pre_ready > 0:
+ pre_row[: num_pre_ready // causal_block_size] = num_iterations
+
+ # Main loop: Generate denoising schedule until all frames are fully denoised
+ while not torch.all(pre_row >= (num_iterations - 1)):
+ # Create new row representing the next denoising step
+ new_row = torch.zeros(num_blocks, dtype=torch.long)
+
+ # Apply diffusion forcing logic for each block
+ for i in range(num_blocks):
+ if i == 0 or pre_row[i - 1] >= (
+ num_iterations - 1
+ ): # the first frame or the last frame is completely denoised
+ new_row[i] = pre_row[i] + 1
+ else:
+ # Asynchronous mode: lag behind previous block by ar_step timesteps
+ # This creates the "diffusion forcing" staggered pattern
+ new_row[i] = new_row[i - 1] - ar_step
+
+ # Clamp values to valid range [0, num_iterations]
+ new_row = new_row.clamp(0, num_iterations)
+
+ # Create update mask: True for blocks that need denoising update at this iteration
+ # Exclude blocks that haven't started (new_row != pre_row) or are finished (new_row != num_iterations)
+ # Final state example: [False, ..., False, True, True, True, True, True]
+ # where first 20 frames are done (False) and last 5 frames still need updates (True)
+ update_mask.append((new_row != pre_row) & (new_row != num_iterations))
+
+ # Store the iteration state
+ step_index.append(new_row) # Index into step_template
+ step_matrix.append(step_template[new_row]) # Actual timestep values
+ pre_row = new_row # Update for next iteration
+
+ # For videos longer than model capacity, we process in sliding windows
+ terminal_flag = base_num_blocks
+
+ # Optional optimization: shrink interval based on first update mask
+ if shrink_interval_with_mask:
+ idx_sequence = torch.arange(num_blocks, dtype=torch.int64)
+ update_mask = update_mask[0]
+ update_mask_idx = idx_sequence[update_mask]
+ last_update_idx = update_mask_idx[-1].item()
+ terminal_flag = last_update_idx + 1
+
+ # Each interval defines which frames to process in the current forward pass
+ for curr_mask in update_mask:
+ # Extend terminal flag if current mask has updates beyond current terminal
+ if terminal_flag < num_blocks and curr_mask[terminal_flag]:
+ terminal_flag += 1
+ # Create interval: [start, end) where start ensures we don't exceed model capacity
+ valid_interval.append((max(terminal_flag - base_num_blocks, 0), terminal_flag))
+
+ # Convert lists to tensors for efficient processing
+ step_update_mask = torch.stack(update_mask, dim=0)
+ step_index = torch.stack(step_index, dim=0)
+ step_matrix = torch.stack(step_matrix, dim=0)
+
+ # Each block's schedule is replicated to all frames within that block
+ if causal_block_size > 1:
+ # Expand each block to causal_block_size frames
+ step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_index = step_index.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, causal_block_size).flatten(1).contiguous()
+ # Scale intervals from block-level to frame-level
+ valid_interval = [(s * causal_block_size, e * causal_block_size) for s, e in valid_interval]
+
+ return step_matrix, step_index, step_update_mask, valid_interval
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1.0
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ video: List[Image.Image],
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 120,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 6.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ overlap_history: Optional[int] = None,
+ addnoise_condition: float = 0,
+ base_num_frames: int = 97,
+ ar_step: int = 0,
+ causal_block_size: Optional[int] = None,
+ fps: int = 24,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ video (`List[Image.Image]`):
+ The video to guide the video generation.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `120`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `6.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality. (**6.0 for T2V**, **5.0 for I2V**)
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`SkyReelsV2PipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+ overlap_history (`int`, *optional*, defaults to `None`):
+ Number of frames to overlap for smooth transitions in long videos. If `None`, the pipeline assumes
+ short video generation mode, and no overlap is applied. 17 and 37 are recommended to set.
+ addnoise_condition (`float`, *optional*, defaults to `0`):
+ This is used to help smooth the long video generation by adding some noise to the clean condition. Too
+ large noise can cause the inconsistency as well. 20 is a recommended value, and you may try larger
+ ones, but it is recommended to not exceed 50.
+ base_num_frames (`int`, *optional*, defaults to `97`):
+ 97 or 121 | Base frame count (**97 for 540P**, **121 for 720P**)
+ ar_step (`int`, *optional*, defaults to `0`):
+ Controls asynchronous inference (0 for synchronous mode) You can set `ar_step=5` to enable asynchronous
+ inference. When asynchronous inference, `causal_block_size=5` is recommended while it is not supposed
+ to be set for synchronous generation. Asynchronous inference will take more steps to diffuse the whole
+ sequence which means it will be SLOWER than synchronous mode. In our experiments, asynchronous
+ inference may improve the instruction following and visual consistent performance.
+ causal_block_size (`int`, *optional*, defaults to `None`):
+ The number of frames in each block/chunk. Recommended when using asynchronous inference (when ar_step >
+ 0)
+ fps (`int`, *optional*, defaults to `24`):
+ Frame rate of the generated video
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+ num_videos_per_prompt = 1
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ height,
+ width,
+ video,
+ latents,
+ prompt_embeds,
+ negative_prompt_embeds,
+ callback_on_step_end_tensor_inputs,
+ overlap_history,
+ num_frames,
+ base_num_frames,
+ )
+
+ if addnoise_condition > 60:
+ logger.warning(
+ f"The value of 'addnoise_condition' is too large ({addnoise_condition}) and may cause inconsistencies in long video generation. A value of 20 is recommended."
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ if latents is None:
+ video_original = self.video_processor.preprocess_video(video, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+
+ if causal_block_size is None:
+ causal_block_size = self.transformer.config.num_frame_per_block
+ else:
+ self.transformer._set_ar_attention(causal_block_size)
+
+ fps_embeds = [fps] * prompt_embeds.shape[0]
+ fps_embeds = [0 if i == 16 else 1 for i in fps_embeds]
+
+ # Long video generation
+ accumulated_latents = None
+ overlap_history_latent_frames = (overlap_history - 1) // self.vae_scale_factor_temporal + 1
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ base_latent_num_frames = (
+ (base_num_frames - 1) // self.vae_scale_factor_temporal + 1
+ if base_num_frames is not None
+ else num_latent_frames
+ )
+ n_iter = (
+ 1
+ + (num_latent_frames - base_latent_num_frames - 1)
+ // (base_latent_num_frames - overlap_history_latent_frames)
+ + 1
+ )
+ for long_video_iter in range(n_iter):
+ logger.debug(f"Processing iteration {long_video_iter + 1}/{n_iter} for long video generation...")
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.transformer.config.in_channels
+ latents, current_num_latent_frames, prefix_video_latents, prefix_video_latents_frames = (
+ self.prepare_latents(
+ video_original,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents if long_video_iter == 0 else None,
+ video_latents=accumulated_latents, # Pass latents directly instead of decoded video
+ overlap_history=overlap_history,
+ base_latent_num_frames=base_latent_num_frames,
+ causal_block_size=causal_block_size,
+ overlap_history_latent_frames=overlap_history_latent_frames,
+ long_video_iter=long_video_iter,
+ )
+ )
+
+ if prefix_video_latents_frames > 0:
+ latents[:, :, :prefix_video_latents_frames, :, :] = prefix_video_latents.to(transformer_dtype)
+
+ # 4. Prepare sample schedulers and timestep matrix
+ sample_schedulers = []
+ for _ in range(current_num_latent_frames):
+ sample_scheduler = deepcopy(self.scheduler)
+ sample_scheduler.set_timesteps(num_inference_steps, device=device)
+ sample_schedulers.append(sample_scheduler)
+ step_matrix, _, step_update_mask, valid_interval = self.generate_timestep_matrix(
+ current_num_latent_frames,
+ timesteps,
+ current_num_latent_frames,
+ ar_step,
+ prefix_video_latents_frames,
+ causal_block_size,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(step_matrix)
+
+ with self.progress_bar(total=len(step_matrix)) as progress_bar:
+ for i, t in enumerate(step_matrix):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ valid_interval_start, valid_interval_end = valid_interval[i]
+ latent_model_input = (
+ latents[:, :, valid_interval_start:valid_interval_end, :, :].to(transformer_dtype).clone()
+ )
+ timestep = t.expand(latents.shape[0], -1)[:, valid_interval_start:valid_interval_end].clone()
+
+ if addnoise_condition > 0 and valid_interval_start < prefix_video_latents_frames:
+ noise_factor = 0.001 * addnoise_condition
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :] = (
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ * (1.0 - noise_factor)
+ + torch.randn_like(
+ latent_model_input[:, :, valid_interval_start:prefix_video_latents_frames, :, :]
+ )
+ * noise_factor
+ )
+ timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ enable_diffusion_forcing=True,
+ fps=fps_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ update_mask_i = step_update_mask[i]
+ for idx in range(valid_interval_start, valid_interval_end):
+ if update_mask_i[idx].item():
+ latents[:, :, idx, :, :] = sample_schedulers[idx].step(
+ noise_pred[:, :, idx - valid_interval_start, :, :],
+ t[idx],
+ latents[:, :, idx, :, :],
+ return_dict=False,
+ )[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(step_matrix) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if accumulated_latents is None:
+ accumulated_latents = latents
+ else:
+ # Keep overlap frames for conditioning but don't include them in final output
+ accumulated_latents = torch.cat(
+ [accumulated_latents, latents[:, :, overlap_history_latent_frames:]], dim=2
+ )
+
+ latents = accumulated_latents
+
+ self._current_timestep = None
+
+ # Final decoding step - convert latents to pixels
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video_generated = self.vae.decode(latents, return_dict=False)[0]
+ video = torch.cat([video_original, video_generated], dim=2)
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
new file mode 100644
index 0000000000..d59b4ce3cb
--- /dev/null
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -0,0 +1,745 @@
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import html
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import PIL
+import regex as re
+import torch
+from transformers import AutoTokenizer, CLIPProcessor, CLIPVisionModelWithProjection, UMT5EncoderModel
+
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
+from ...image_processor import PipelineImageInput
+from ...loaders import SkyReelsV2LoraLoaderMixin
+from ...models import AutoencoderKLWan, SkyReelsV2Transformer3DModel
+from ...schedulers import UniPCMultistepScheduler
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
+from ...utils.torch_utils import randn_tensor
+from ...video_processor import VideoProcessor
+from ..pipeline_utils import DiffusionPipeline
+from .pipeline_output import SkyReelsV2PipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+if is_ftfy_available():
+ import ftfy
+
+
+EXAMPLE_DOC_STRING = """\
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import (
+ ... SkyReelsV2ImageToVideoPipeline,
+ ... UniPCMultistepScheduler,
+ ... AutoencoderKLWan,
+ ... )
+ >>> from diffusers.utils import export_to_video
+ >>> from PIL import Image
+
+ >>> # Load the pipeline
+ >>> # Available models:
+ >>> # - Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-I2V-14B-540P-Diffusers
+ >>> # - Skywork/SkyReels-V2-I2V-14B-720P-Diffusers
+ >>> vae = AutoencoderKLWan.from_pretrained(
+ ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
+ ... subfolder="vae",
+ ... torch_dtype=torch.float32,
+ ... )
+ >>> pipe = SkyReelsV2ImageToVideoPipeline.from_pretrained(
+ ... "Skywork/SkyReels-V2-I2V-14B-720P-Diffusers",
+ ... vae=vae,
+ ... torch_dtype=torch.bfloat16,
+ ... )
+ >>> flow_shift = 5.0 # 8.0 for T2V, 5.0 for I2V
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
+ >>> image = Image.open("path/to/image.png")
+
+ >>> output = pipe(
+ ... image=image,
+ ... prompt=prompt,
+ ... num_inference_steps=50,
+ ... height=544,
+ ... width=960,
+ ... guidance_scale=5.0, # 6.0 for T2V, 5.0 for I2V
+ ... num_frames=97,
+ ... ).frames[0]
+ >>> export_to_video(output, "video.mp4", fps=24, quality=8)
+ ```
+"""
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r"\s+", " ", text)
+ text = text.strip()
+ return text
+
+
+def prompt_clean(text):
+ text = whitespace_clean(basic_clean(text))
+ return text
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+ return encoder_output.latent_dist.sample(generator)
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+ return encoder_output.latent_dist.mode()
+ elif hasattr(encoder_output, "latents"):
+ return encoder_output.latents
+ else:
+ raise AttributeError("Could not access latents of provided encoder_output")
+
+
+class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
+ r"""
+ Pipeline for Image-to-Video (i2v) generation using SkyReels-V2.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ Args:
+ tokenizer ([`T5Tokenizer`]):
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ text_encoder ([`T5EncoderModel`]):
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
+ image_encoder ([`CLIPVisionModelWithProjection`]):
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),
+ specifically the
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
+ variant.
+ transformer ([`SkyReelsV2Transformer3DModel`]):
+ Conditional Transformer to denoise the input latents.
+ scheduler ([`UniPCMultistepScheduler`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ vae ([`AutoencoderKLWan`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ tokenizer: AutoTokenizer,
+ text_encoder: UMT5EncoderModel,
+ image_encoder: CLIPVisionModelWithProjection,
+ image_processor: CLIPProcessor,
+ transformer: SkyReelsV2Transformer3DModel,
+ vae: AutoencoderKLWan,
+ scheduler: UniPCMultistepScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ image_encoder=image_encoder,
+ transformer=transformer,
+ scheduler=scheduler,
+ image_processor=image_processor,
+ )
+
+ self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
+ self.image_processor = image_processor
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline._get_t5_prompt_embeds
+ def _get_t5_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]] = None,
+ num_videos_per_prompt: int = 1,
+ max_sequence_length: int = 512,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ device = device or self._execution_device
+ dtype = dtype or self.text_encoder.dtype
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ prompt = [prompt_clean(u) for u in prompt]
+ batch_size = len(prompt)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ add_special_tokens=True,
+ return_attention_mask=True,
+ return_tensors="pt",
+ )
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
+ seq_lens = mask.gt(0).sum(dim=1).long()
+
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
+ prompt_embeds = torch.stack(
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
+ )
+
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ _, seq_len, _ = prompt_embeds.shape
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_image
+ def encode_image(
+ self,
+ image: PipelineImageInput,
+ device: Optional[torch.device] = None,
+ ):
+ device = device or self._execution_device
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
+ return image_embeds.hidden_states[-2]
+
+ # Copied from diffusers.pipelines.wan.pipeline_wan_i2v.WanImageToVideoPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ do_classifier_free_guidance: bool = True,
+ num_videos_per_prompt: int = 1,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ max_sequence_length: int = 226,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ Whether to use classifier free guidance or not.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ device: (`torch.device`, *optional*):
+ torch device
+ dtype: (`torch.dtype`, *optional*):
+ torch dtype
+ """
+ device = device or self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt or ""
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
+ prompt=negative_prompt,
+ num_videos_per_prompt=num_videos_per_prompt,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ dtype=dtype,
+ )
+
+ return prompt_embeds, negative_prompt_embeds
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ image_embeds=None,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if image is not None and image_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ if image is None and image_embeds is None:
+ raise ValueError(
+ "Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
+ )
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
+ if height % 16 != 0 or width % 16 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ elif negative_prompt is not None and (
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
+ ):
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+
+ def prepare_latents(
+ self,
+ image: PipelineImageInput,
+ batch_size: int,
+ num_channels_latents: int = 16,
+ height: int = 480,
+ width: int = 832,
+ num_frames: int = 81,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+ latent_height = height // self.vae_scale_factor_spatial
+ latent_width = width // self.vae_scale_factor_spatial
+
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device=device, dtype=dtype)
+
+ image = image.unsqueeze(2)
+ if last_image is None:
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+ )
+ else:
+ last_image = last_image.unsqueeze(2)
+ video_condition = torch.cat(
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+ dim=2,
+ )
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
+
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+
+ if isinstance(generator, list):
+ latent_condition = [
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+ ]
+ latent_condition = torch.cat(latent_condition)
+ else:
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+ latent_condition = latent_condition.to(dtype)
+ latent_condition = (latent_condition - latents_mean) * latents_std
+
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+
+ if last_image is None:
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
+ else:
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+ first_frame_mask = mask_lat_size[:, :, 0:1]
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
+ mask_lat_size = mask_lat_size.transpose(1, 2)
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
+
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @property
+ def current_timestep(self):
+ return self._current_timestep
+
+ @property
+ def interrupt(self):
+ return self._interrupt
+
+ @property
+ def attention_kwargs(self):
+ return self._attention_kwargs
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image: PipelineImageInput,
+ prompt: Union[str, List[str]] = None,
+ negative_prompt: Union[str, List[str]] = None,
+ height: int = 544,
+ width: int = 960,
+ num_frames: int = 97,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 5.0,
+ num_videos_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ image_embeds: Optional[torch.Tensor] = None,
+ last_image: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "np",
+ return_dict: bool = True,
+ attention_kwargs: Optional[Dict[str, Any]] = None,
+ callback_on_step_end: Optional[
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+ ] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ max_sequence_length: int = 512,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ image (`PipelineImageInput`):
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ height (`int`, defaults to `544`):
+ The height of the generated video.
+ width (`int`, defaults to `960`):
+ The width of the generated video.
+ num_frames (`int`, defaults to `97`):
+ The number of frames in the generated video.
+ num_inference_steps (`int`, defaults to `50`):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, defaults to `5.0`):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `negative_prompt` input argument.
+ image_embeds (`torch.Tensor`, *optional*):
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
+ image embeddings are generated from the `image` input argument.
+ output_type (`str`, *optional*, defaults to `"np"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
+ attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ max_sequence_length (`int`, *optional*, defaults to `512`):
+ The maximum sequence length of the prompt.
+
+ Examples:
+
+ Returns:
+ [`~SkyReelsV2PipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`SkyReelsV2PipelineOutput`] is returned, otherwise a `tuple` is returned
+ where the first element is a list with the generated images and the second element is a list of `bool`s
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
+ """
+
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ negative_prompt,
+ image,
+ height,
+ width,
+ prompt_embeds,
+ negative_prompt_embeds,
+ image_embeds,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ if num_frames % self.vae_scale_factor_temporal != 1:
+ logger.warning(
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+ )
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+ num_frames = max(num_frames, 1)
+
+ self._guidance_scale = guidance_scale
+ self._attention_kwargs = attention_kwargs
+ self._current_timestep = None
+ self._interrupt = False
+
+ device = self._execution_device
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ # 3. Encode input prompt
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ num_videos_per_prompt=num_videos_per_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ max_sequence_length=max_sequence_length,
+ device=device,
+ )
+
+ # Encode image embedding
+ transformer_dtype = self.transformer.dtype
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
+ if negative_prompt_embeds is not None:
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.vae.config.z_dim
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+ if last_image is not None:
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+ device, dtype=torch.float32
+ )
+ latents, condition = self.prepare_latents(
+ image,
+ batch_size * num_videos_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ num_frames,
+ torch.float32,
+ device,
+ generator,
+ latents,
+ last_image,
+ )
+
+ # 6. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ self._num_timesteps = len(timesteps)
+
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ if self.interrupt:
+ continue
+
+ self._current_timestep = t
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ noise_pred = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.do_classifier_free_guidance:
+ noise_uncond = self.transformer(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ self._current_timestep = None
+
+ if not output_type == "latent":
+ latents = latents.to(self.vae.dtype)
+ latents_mean = (
+ torch.tensor(self.vae.config.latents_mean)
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
+ .to(latents.device, latents.dtype)
+ )
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+ latents.device, latents.dtype
+ )
+ latents = latents / latents_std + latents_mean
+ video = self.vae.decode(latents, return_dict=False)[0]
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
+ else:
+ video = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (video,)
+
+ return SkyReelsV2PipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md
index 2dc538f858..164baeb0a4 100644
--- a/src/diffusers/pipelines/stable_diffusion/README.md
+++ b/src/diffusers/pipelines/stable_diffusion/README.md
@@ -28,7 +28,7 @@ download the weights with `git lfs install; git clone https://huggingface.co/sta
### Using Stable Diffusion without being logged into the Hub.
-If you want to download the model weights using a single Python line, you need to be logged in via `huggingface-cli login`.
+If you want to download the model weights using a single Python line, you need to be logged in via `hf auth login`.
```python
from diffusers import DiffusionPipeline
@@ -54,7 +54,7 @@ pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
### Text-to-Image with default PLMS scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
@@ -69,7 +69,7 @@ image.save("astronaut_rides_horse.png")
### Text-to-Image with DDIM scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, DDIMScheduler
scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
@@ -88,7 +88,7 @@ image.save("astronaut_rides_horse.png")
### Text-to-Image with K-LMS scheduler
```python
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler
lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
@@ -118,7 +118,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the scheduler. CycleDiffusion only supports stochastic schedulers.
# load the pipeline
-# make sure you're logged in with `huggingface-cli login`
+# make sure you're logged in with `hf auth login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
index bd8609a11a..06c2076816 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
# set timesteps
self.scheduler.set_timesteps(num_inference_steps)
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
index 6a952a7ae6..141d849ec3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
# scale the initial noise by the standard deviation required by the scheduler
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ latents = latents * self.scheduler.init_noise_sigma
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
index 3f10764dc7..882fa98b07 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
timesteps = self.scheduler.timesteps
# Scale the initial noise by the standard deviation required by the scheduler
- latents = latents * np.float64(self.scheduler.init_noise_sigma)
+ latents = latents * self.scheduler.init_noise_sigma
# 5. Add noise to image
noise_level = np.array([noise_level]).astype(np.int64)
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index ad4c4b0917..afee3f61e9 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -25,6 +25,7 @@ from transformers import (
T5TokenizerFast,
)
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
- _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
def __init__(
self,
@@ -923,6 +924,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
@@ -1109,10 +1113,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
- negative_pooled_prompt_embeds = callback_outputs.pop(
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
- )
+ pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/wan/pipeline_wan.py b/src/diffusers/pipelines/wan/pipeline_wan.py
index 6df66118b0..f52bf33d81 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan.py
@@ -112,10 +112,20 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+ two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+ stages. If not provided, only `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
- model_cpu_offload_seq = "text_encoder->transformer->vae"
+ model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer_2"]
def __init__(
self,
@@ -124,6 +134,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ transformer_2: Optional[WanTransformer3DModel] = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False, # Wan2.2 ti2v
):
super().__init__()
@@ -133,10 +146,12 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
+ transformer_2=transformer_2,
)
-
- self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
- self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.register_to_config(boundary_ratio=boundary_ratio)
+ self.register_to_config(expand_timesteps=expand_timesteps)
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
def _get_t5_prompt_embeds(
@@ -270,6 +285,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
def prepare_latents(
self,
batch_size: int,
@@ -369,6 +388,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -520,36 +549,61 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents,
)
+ mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
+
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
+
latent_model_input = latents.to(transformer_dtype)
- timestep = t.expand(latents.shape[0])
+ if self.config.expand_timesteps:
+ # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+ temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
-
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
index c71138a97d..a072824a48 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -149,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+ `transformer` is used.
+ boundary_ratio (`float`, *optional*, defaults to `None`):
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
"""
- model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+ _optional_components = ["transformer_2", "image_encoder", "image_processor"]
def __init__(
self,
tokenizer: AutoTokenizer,
text_encoder: UMT5EncoderModel,
- image_encoder: CLIPVisionModel,
- image_processor: CLIPImageProcessor,
transformer: WanTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: FlowMatchEulerDiscreteScheduler,
+ image_processor: CLIPImageProcessor = None,
+ image_encoder: CLIPVisionModel = None,
+ transformer_2: WanTransformer3DModel = None,
+ boundary_ratio: Optional[float] = None,
+ expand_timesteps: bool = False,
):
super().__init__()
@@ -174,10 +187,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
transformer=transformer,
scheduler=scheduler,
image_processor=image_processor,
+ transformer_2=transformer_2,
)
+ self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
- self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
- self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
self.image_processor = image_processor
@@ -325,6 +340,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds=None,
image_embeds=None,
callback_on_step_end_tensor_inputs=None,
+ guidance_scale_2=None,
):
if image is not None and image_embeds is not None:
raise ValueError(
@@ -368,6 +384,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+ if self.config.boundary_ratio is not None and image_embeds is not None:
+ raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
+
def prepare_latents(
self,
image: PipelineImageInput,
@@ -398,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
else:
latents = latents.to(device=device, dtype=dtype)
- image = image.unsqueeze(2)
- if last_image is None:
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
+
+ if self.config.expand_timesteps:
+ video_condition = image
+
+ elif last_image is None:
video_condition = torch.cat(
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
)
@@ -432,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latent_condition = latent_condition.to(dtype)
latent_condition = (latent_condition - latents_mean) * latents_std
+ if self.config.expand_timesteps:
+ first_frame_mask = torch.ones(
+ 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+ )
+ first_frame_mask[:, :, 0] = 0
+ return latents, latent_condition, first_frame_mask
+
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
if last_image is None:
@@ -483,6 +516,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames: int = 81,
num_inference_steps: int = 50,
guidance_scale: float = 5.0,
+ guidance_scale_2: Optional[float] = None,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -527,6 +561,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+ and the pipeline's `boundary_ratio` are not None.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +627,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
negative_prompt_embeds,
image_embeds,
callback_on_step_end_tensor_inputs,
+ guidance_scale_2,
)
if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +637,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
num_frames = max(num_frames, 1)
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+ guidance_scale_2 = guidance_scale
+
self._guidance_scale = guidance_scale
+ self._guidance_scale_2 = guidance_scale_2
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
@@ -631,13 +674,14 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
if negative_prompt_embeds is not None:
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
- if image_embeds is None:
- if last_image is None:
- image_embeds = self.encode_image(image, device)
- else:
- image_embeds = self.encode_image([image, last_image], device)
- image_embeds = image_embeds.repeat(batch_size, 1, 1)
- image_embeds = image_embeds.to(transformer_dtype)
+ if self.config.boundary_ratio is None and not self.config.expand_timesteps:
+ if image_embeds is None:
+ if last_image is None:
+ image_embeds = self.encode_image(image, device)
+ else:
+ image_embeds = self.encode_image([image, last_image], device)
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
+ image_embeds = image_embeds.to(transformer_dtype)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -650,7 +694,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
device, dtype=torch.float32
)
- latents, condition = self.prepare_latents(
+
+ latents_outputs = self.prepare_latents(
image,
batch_size * num_videos_per_prompt,
num_channels_latents,
@@ -663,39 +708,69 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
latents,
last_image,
)
+ if self.config.expand_timesteps:
+ latents, condition, first_frame_mask = latents_outputs
+ else:
+ latents, condition = latents_outputs
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
+ if self.config.boundary_ratio is not None:
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
- latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
- timestep = t.expand(latents.shape[0])
- noise_pred = self.transformer(
- hidden_states=latent_model_input,
- timestep=timestep,
- encoder_hidden_states=prompt_embeds,
- encoder_hidden_states_image=image_embeds,
- attention_kwargs=attention_kwargs,
- return_dict=False,
- )[0]
+ if boundary_timestep is None or t >= boundary_timestep:
+ # wan2.1 or high-noise stage in wan2.2
+ current_model = self.transformer
+ current_guidance_scale = guidance_scale
+ else:
+ # low-noise stage in wan2.2
+ current_model = self.transformer_2
+ current_guidance_scale = guidance_scale_2
- if self.do_classifier_free_guidance:
- noise_uncond = self.transformer(
+ if self.config.expand_timesteps:
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+ latent_model_input = latent_model_input.to(transformer_dtype)
+
+ # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+ # batch_size, seq_len
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+ else:
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+ timestep = t.expand(latents.shape[0])
+
+ with current_model.cache_context("cond"):
+ noise_pred = current_model(
hidden_states=latent_model_input,
timestep=timestep,
- encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
- noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+ if self.do_classifier_free_guidance:
+ with current_model.cache_context("uncond"):
+ noise_uncond = current_model(
+ hidden_states=latent_model_input,
+ timestep=timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ encoder_hidden_states_image=image_embeds,
+ attention_kwargs=attention_kwargs,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
@@ -719,6 +794,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
self._current_timestep = None
+ if self.config.expand_timesteps:
+ latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
diff --git a/src/diffusers/quantizers/__init__.py b/src/diffusers/quantizers/__init__.py
index efd2418753..3ca867c129 100644
--- a/src/diffusers/quantizers/__init__.py
+++ b/src/diffusers/quantizers/__init__.py
@@ -12,183 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
-from typing import Dict, List, Optional, Union
-from ..utils import is_transformers_available, logging
from .auto import DiffusersAutoQuantizer
from .base import DiffusersQuantizer
-from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
-
-
-try:
- from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
-except ImportError:
-
- class TransformersQuantConfigMixin:
- pass
-
-
-logger = logging.get_logger(__name__)
-
-
-class PipelineQuantizationConfig:
- """
- Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
-
- Args:
- quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
- is available to both `diffusers` and `transformers`.
- quant_kwargs (`dict`): Params to initialize the quantization backend class.
- components_to_quantize (`list`): Components of a pipeline to be quantized.
- quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
- components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
- and `components_to_quantize`.
- """
-
- def __init__(
- self,
- quant_backend: str = None,
- quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
- components_to_quantize: Optional[List[str]] = None,
- quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
- ):
- self.quant_backend = quant_backend
- # Initialize kwargs to be {} to set to the defaults.
- self.quant_kwargs = quant_kwargs or {}
- self.components_to_quantize = components_to_quantize
- self.quant_mapping = quant_mapping
-
- self.post_init()
-
- def post_init(self):
- quant_mapping = self.quant_mapping
- self.is_granular = True if quant_mapping is not None else False
-
- self._validate_init_args()
-
- def _validate_init_args(self):
- if self.quant_backend and self.quant_mapping:
- raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
-
- if not self.quant_mapping and not self.quant_backend:
- raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
-
- if not self.quant_kwargs and not self.quant_mapping:
- raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
-
- if self.quant_backend is not None:
- self._validate_init_kwargs_in_backends()
-
- if self.quant_mapping is not None:
- self._validate_quant_mapping_args()
-
- def _validate_init_kwargs_in_backends(self):
- quant_backend = self.quant_backend
-
- self._check_backend_availability(quant_backend)
-
- quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
- if quant_config_mapping_transformers is not None:
- init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
- init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
- else:
- init_kwargs_transformers = None
-
- init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
- init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
-
- if init_kwargs_transformers != init_kwargs_diffusers:
- raise ValueError(
- "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
- f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
- "this mapping would look like."
- )
-
- def _validate_quant_mapping_args(self):
- quant_mapping = self.quant_mapping
- transformers_map, diffusers_map = self._get_quant_config_list()
-
- available_transformers = list(transformers_map.values()) if transformers_map else None
- available_diffusers = list(diffusers_map.values())
-
- for module_name, config in quant_mapping.items():
- if any(isinstance(config, cfg) for cfg in available_diffusers):
- continue
-
- if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
- continue
-
- if available_transformers:
- raise ValueError(
- f"Provided config for module_name={module_name} could not be found. "
- f"Available diffusers configs: {available_diffusers}; "
- f"Available transformers configs: {available_transformers}."
- )
- else:
- raise ValueError(
- f"Provided config for module_name={module_name} could not be found. "
- f"Available diffusers configs: {available_diffusers}."
- )
-
- def _check_backend_availability(self, quant_backend: str):
- quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
- available_backends_transformers = (
- list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
- )
- available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
-
- if (
- available_backends_transformers and quant_backend not in available_backends_transformers
- ) or quant_backend not in quant_config_mapping_diffusers:
- error_message = f"Provided quant_backend={quant_backend} was not found."
- if available_backends_transformers:
- error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
- error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
- raise ValueError(error_message)
-
- def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
- quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
-
- quant_mapping = self.quant_mapping
- components_to_quantize = self.components_to_quantize
-
- # Granular case
- if self.is_granular and module_name in quant_mapping:
- logger.debug(f"Initializing quantization config class for {module_name}.")
- config = quant_mapping[module_name]
- return config
-
- # Global config case
- else:
- should_quantize = False
- # Only quantize the modules requested for.
- if components_to_quantize and module_name in components_to_quantize:
- should_quantize = True
- # No specification for `components_to_quantize` means all modules should be quantized.
- elif not self.is_granular and not components_to_quantize:
- should_quantize = True
-
- if should_quantize:
- logger.debug(f"Initializing quantization config class for {module_name}.")
- mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
- quant_config_cls = mapping_to_use[self.quant_backend]
- quant_kwargs = self.quant_kwargs
- return quant_config_cls(**quant_kwargs)
-
- # Fallback: no applicable configuration found.
- return None
-
- def _get_quant_config_list(self):
- if is_transformers_available():
- from transformers.quantizers.auto import (
- AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
- )
- else:
- quant_config_mapping_transformers = None
-
- from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
-
- return quant_config_mapping_transformers, quant_config_mapping_diffusers
+from .pipe_quant_config import PipelineQuantizationConfig
diff --git a/src/diffusers/quantizers/pipe_quant_config.py b/src/diffusers/quantizers/pipe_quant_config.py
new file mode 100644
index 0000000000..5d02de16fd
--- /dev/null
+++ b/src/diffusers/quantizers/pipe_quant_config.py
@@ -0,0 +1,202 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Dict, List, Optional, Union
+
+from ..utils import is_transformers_available, logging
+from .quantization_config import QuantizationConfigMixin as DiffQuantConfigMixin
+
+
+try:
+ from transformers.utils.quantization_config import QuantizationConfigMixin as TransformersQuantConfigMixin
+except ImportError:
+
+ class TransformersQuantConfigMixin:
+ pass
+
+
+logger = logging.get_logger(__name__)
+
+
+class PipelineQuantizationConfig:
+ """
+ Configuration class to be used when applying quantization on-the-fly to [`~DiffusionPipeline.from_pretrained`].
+
+ Args:
+ quant_backend (`str`): Quantization backend to be used. When using this option, we assume that the backend
+ is available to both `diffusers` and `transformers`.
+ quant_kwargs (`dict`): Params to initialize the quantization backend class.
+ components_to_quantize (`list`): Components of a pipeline to be quantized.
+ quant_mapping (`dict`): Mapping defining the quantization specs to be used for the pipeline
+ components. When using this argument, users are not expected to provide `quant_backend`, `quant_kawargs`,
+ and `components_to_quantize`.
+ """
+
+ def __init__(
+ self,
+ quant_backend: str = None,
+ quant_kwargs: Dict[str, Union[str, float, int, dict]] = None,
+ components_to_quantize: Optional[List[str]] = None,
+ quant_mapping: Dict[str, Union[DiffQuantConfigMixin, "TransformersQuantConfigMixin"]] = None,
+ ):
+ self.quant_backend = quant_backend
+ # Initialize kwargs to be {} to set to the defaults.
+ self.quant_kwargs = quant_kwargs or {}
+ self.components_to_quantize = components_to_quantize
+ self.quant_mapping = quant_mapping
+ self.config_mapping = {} # book-keeping Example: `{module_name: quant_config}`
+ self.post_init()
+
+ def post_init(self):
+ quant_mapping = self.quant_mapping
+ self.is_granular = True if quant_mapping is not None else False
+
+ self._validate_init_args()
+
+ def _validate_init_args(self):
+ if self.quant_backend and self.quant_mapping:
+ raise ValueError("Both `quant_backend` and `quant_mapping` cannot be specified at the same time.")
+
+ if not self.quant_mapping and not self.quant_backend:
+ raise ValueError("Must provide a `quant_backend` when not providing a `quant_mapping`.")
+
+ if not self.quant_kwargs and not self.quant_mapping:
+ raise ValueError("Both `quant_kwargs` and `quant_mapping` cannot be None.")
+
+ if self.quant_backend is not None:
+ self._validate_init_kwargs_in_backends()
+
+ if self.quant_mapping is not None:
+ self._validate_quant_mapping_args()
+
+ def _validate_init_kwargs_in_backends(self):
+ quant_backend = self.quant_backend
+
+ self._check_backend_availability(quant_backend)
+
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ if quant_config_mapping_transformers is not None:
+ init_kwargs_transformers = inspect.signature(quant_config_mapping_transformers[quant_backend].__init__)
+ init_kwargs_transformers = {name for name in init_kwargs_transformers.parameters if name != "self"}
+ else:
+ init_kwargs_transformers = None
+
+ init_kwargs_diffusers = inspect.signature(quant_config_mapping_diffusers[quant_backend].__init__)
+ init_kwargs_diffusers = {name for name in init_kwargs_diffusers.parameters if name != "self"}
+
+ if init_kwargs_transformers != init_kwargs_diffusers:
+ raise ValueError(
+ "The signatures of the __init__ methods of the quantization config classes in `diffusers` and `transformers` don't match. "
+ f"Please provide a `quant_mapping` instead, in the {self.__class__.__name__} class. Refer to [the docs](https://huggingface.co/docs/diffusers/main/en/quantization/overview#pipeline-level-quantization) to learn more about how "
+ "this mapping would look like."
+ )
+
+ def _validate_quant_mapping_args(self):
+ quant_mapping = self.quant_mapping
+ transformers_map, diffusers_map = self._get_quant_config_list()
+
+ available_transformers = list(transformers_map.values()) if transformers_map else None
+ available_diffusers = list(diffusers_map.values())
+
+ for module_name, config in quant_mapping.items():
+ if any(isinstance(config, cfg) for cfg in available_diffusers):
+ continue
+
+ if available_transformers and any(isinstance(config, cfg) for cfg in available_transformers):
+ continue
+
+ if available_transformers:
+ raise ValueError(
+ f"Provided config for module_name={module_name} could not be found. "
+ f"Available diffusers configs: {available_diffusers}; "
+ f"Available transformers configs: {available_transformers}."
+ )
+ else:
+ raise ValueError(
+ f"Provided config for module_name={module_name} could not be found. "
+ f"Available diffusers configs: {available_diffusers}."
+ )
+
+ def _check_backend_availability(self, quant_backend: str):
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ available_backends_transformers = (
+ list(quant_config_mapping_transformers.keys()) if quant_config_mapping_transformers else None
+ )
+ available_backends_diffusers = list(quant_config_mapping_diffusers.keys())
+
+ if (
+ available_backends_transformers and quant_backend not in available_backends_transformers
+ ) or quant_backend not in quant_config_mapping_diffusers:
+ error_message = f"Provided quant_backend={quant_backend} was not found."
+ if available_backends_transformers:
+ error_message += f"\nAvailable ones (transformers): {available_backends_transformers}."
+ error_message += f"\nAvailable ones (diffusers): {available_backends_diffusers}."
+ raise ValueError(error_message)
+
+ def _resolve_quant_config(self, is_diffusers: bool = True, module_name: str = None):
+ quant_config_mapping_transformers, quant_config_mapping_diffusers = self._get_quant_config_list()
+
+ quant_mapping = self.quant_mapping
+ components_to_quantize = self.components_to_quantize
+
+ # Granular case
+ if self.is_granular and module_name in quant_mapping:
+ logger.debug(f"Initializing quantization config class for {module_name}.")
+ config = quant_mapping[module_name]
+ self.config_mapping.update({module_name: config})
+ return config
+
+ # Global config case
+ else:
+ should_quantize = False
+ # Only quantize the modules requested for.
+ if components_to_quantize and module_name in components_to_quantize:
+ should_quantize = True
+ # No specification for `components_to_quantize` means all modules should be quantized.
+ elif not self.is_granular and not components_to_quantize:
+ should_quantize = True
+
+ if should_quantize:
+ logger.debug(f"Initializing quantization config class for {module_name}.")
+ mapping_to_use = quant_config_mapping_diffusers if is_diffusers else quant_config_mapping_transformers
+ quant_config_cls = mapping_to_use[self.quant_backend]
+ quant_kwargs = self.quant_kwargs
+ quant_obj = quant_config_cls(**quant_kwargs)
+ self.config_mapping.update({module_name: quant_obj})
+ return quant_obj
+
+ # Fallback: no applicable configuration found.
+ return None
+
+ def _get_quant_config_list(self):
+ if is_transformers_available():
+ from transformers.quantizers.auto import (
+ AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_transformers,
+ )
+ else:
+ quant_config_mapping_transformers = None
+
+ from ..quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING as quant_config_mapping_diffusers
+
+ return quant_config_mapping_transformers, quant_config_mapping_diffusers
+
+ def __repr__(self):
+ out = ""
+ config_mapping = dict(sorted(self.config_mapping.copy().items()))
+ for module_name, config in config_mapping.items():
+ out += f"{module_name} {config}"
+ return out
diff --git a/src/diffusers/schedulers/scheduling_deis_multistep.py b/src/diffusers/schedulers/scheduling_deis_multistep.py
index 748a7e39c0..7d8685ba10 100644
--- a/src/diffusers/schedulers/scheduling_deis_multistep.py
+++ b/src/diffusers/schedulers/scheduling_deis_multistep.py
@@ -153,6 +153,8 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
flow_shift: Optional[float] = 1.0,
timestep_spacing: str = "linspace",
steps_offset: int = 0,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -232,7 +234,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+ ):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -242,6 +246,9 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = (
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 1a648af5a0..d07ff8b200 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -230,6 +230,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
timestep_spacing: str = "linspace",
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -330,6 +332,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
+ mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
):
"""
@@ -345,6 +348,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
based on the `timestep_spacing` attribute. If `timesteps` is passed, `num_inference_steps` and `sigmas`
must be `None`, and `timestep_spacing` attribute will be ignored.
"""
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
index 9e3e830039..8663210a62 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py
@@ -169,6 +169,8 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
variance_type: Optional[str] = None,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -301,6 +303,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
+ mu: Optional[float] = None,
timesteps: Optional[List[int]] = None,
):
"""
@@ -316,6 +319,9 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
timestep spacing strategy of equal spacing between timesteps schedule is used. If `timesteps` is
passed, `num_inference_steps` must be `None`.
"""
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if num_inference_steps is None and timesteps is None:
raise ValueError("Must pass exactly one of `num_inference_steps` or `timesteps`.")
if num_inference_steps is not None and timesteps is not None:
diff --git a/src/diffusers/schedulers/scheduling_scm.py b/src/diffusers/schedulers/scheduling_scm.py
index acff268c9b..63b4a109ff 100644
--- a/src/diffusers/schedulers/scheduling_scm.py
+++ b/src/diffusers/schedulers/scheduling_scm.py
@@ -168,7 +168,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(max_timesteps, 0, num_inference_steps + 1, device=device).float()
- print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
diff --git a/src/diffusers/schedulers/scheduling_unipc_multistep.py b/src/diffusers/schedulers/scheduling_unipc_multistep.py
index 8b1f699b10..162a34bd27 100644
--- a/src/diffusers/schedulers/scheduling_unipc_multistep.py
+++ b/src/diffusers/schedulers/scheduling_unipc_multistep.py
@@ -168,6 +168,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
use_beta_sigmas (`bool`, *optional*, defaults to `False`):
Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
+ use_flow_sigmas (`bool`, *optional*, defaults to `False`):
+ Whether to use flow sigmas for step sizes in the noise schedule during the sampling process.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -212,6 +214,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
steps_offset: int = 0,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
rescale_betas_zero_snr: bool = False,
+ use_dynamic_shifting: bool = False,
+ time_shift_type: str = "exponential",
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -298,7 +302,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
- def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ def set_timesteps(
+ self, num_inference_steps: int, device: Union[str, torch.device] = None, mu: Optional[float] = None
+ ):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -309,6 +315,9 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
+ if mu is not None:
+ assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
+ self.config.flow_shift = np.exp(mu)
if self.config.timestep_spacing == "linspace":
timesteps = (
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py
index 61d3e5a22f..f0e162ea6b 100644
--- a/src/diffusers/schedulers/scheduling_utils.py
+++ b/src/diffusers/schedulers/scheduling_utils.py
@@ -140,8 +140,8 @@ class SchedulerMixin(PushToHubMixin):
- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`. You can also activate the special
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
firewalled environment.
diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py
index abcde6c386..e6ac78f63e 100644
--- a/src/diffusers/schedulers/scheduling_utils_flax.py
+++ b/src/diffusers/schedulers/scheduling_utils_flax.py
@@ -120,7 +120,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
- It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
+ It is required to be logged in (`hf auth login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index bc30411d87..d33b80dba0 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -3,12 +3,16 @@ import copy
import gc
import math
import random
+import re
+import warnings
+from contextlib import contextmanager
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
+from .pipelines import DiffusionPipeline
from .schedulers import SchedulerMixin
from .utils import (
convert_state_dict_to_diffusers,
@@ -316,6 +320,79 @@ def free_memory():
torch.xpu.empty_cache()
+@contextmanager
+def offload_models(
+ *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
+):
+ """
+ Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
+ device on exit.
+
+ Args:
+ device (`str` or `torch.Device`): Device to move the `modules` to.
+ offload (`bool`): Flag to enable offloading.
+ """
+ if offload:
+ is_model = not any(isinstance(m, DiffusionPipeline) for m in modules)
+ # record where each module was
+ if is_model:
+ original_devices = [next(m.parameters()).device for m in modules]
+ else:
+ assert len(modules) == 1
+ original_devices = modules[0].device
+ # move to target device
+ for m in modules:
+ m.to(device)
+
+ try:
+ yield
+ finally:
+ if offload:
+ # move back to original devices
+ for m, orig_dev in zip(modules, original_devices):
+ m.to(orig_dev)
+
+
+def parse_buckets_string(buckets_str):
+ """Parses a string defining buckets into a list of (height, width) tuples."""
+ if not buckets_str:
+ raise ValueError("Bucket string cannot be empty.")
+
+ bucket_pairs = buckets_str.strip().split(";")
+ parsed_buckets = []
+ for pair_str in bucket_pairs:
+ match = re.match(r"^\s*(\d+)\s*,\s*(\d+)\s*$", pair_str)
+ if not match:
+ raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
+ try:
+ height = int(match.group(1))
+ width = int(match.group(2))
+ if height <= 0 or width <= 0:
+ raise ValueError("Bucket dimensions must be positive integers.")
+ if height % 8 != 0 or width % 8 != 0:
+ warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
+ parsed_buckets.append((height, width))
+ except ValueError as e:
+ raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e
+
+ if not parsed_buckets:
+ raise ValueError("No valid buckets found in the provided string.")
+
+ return parsed_buckets
+
+
+def find_nearest_bucket(h, w, bucket_options):
+ """Finds the closes bucket to the given height and width."""
+ min_metric = float("inf")
+ best_bucket_idx = None
+ for bucket_idx, (bucket_h, bucket_w) in enumerate(bucket_options):
+ metric = abs(h * bucket_w - w * bucket_h)
+ if metric <= min_metric:
+ min_metric = metric
+ best_bucket_idx = bucket_idx
+ return best_bucket_idx
+
+
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index 2df05cb8eb..cadcedb98a 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -67,6 +67,9 @@ from .import_utils import (
is_bitsandbytes_version,
is_bs4_available,
is_cosmos_guardrail_available,
+ is_flash_attn_3_available,
+ is_flash_attn_available,
+ is_flash_attn_version,
is_flax_available,
is_ftfy_available,
is_gguf_available,
@@ -90,6 +93,8 @@ from .import_utils import (
is_peft_version,
is_pytorch_retinaface_available,
is_safetensors_available,
+ is_sageattention_available,
+ is_sageattention_version,
is_scipy_available,
is_sentencepiece_available,
is_tensorboard_available,
@@ -108,6 +113,7 @@ from .import_utils import (
is_unidecode_available,
is_wandb_available,
is_xformers_available,
+ is_xformers_version,
requires_backends,
)
from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index 7c04287d33..f8f04cc03a 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -41,6 +41,8 @@ DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60
+DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
+DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 2981f3a420..901aec4b22 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -2,6 +2,126 @@
from ..utils import DummyObject, requires_backends
+class AdaptiveProjectedGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class AutoGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ClassifierFreeGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ClassifierFreeZeroStarGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class PerturbedAttentionGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SkipLayerGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class SmoothedEnergyGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class TangentialClassifierFreeGuidance(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class FasterCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -17,6 +137,21 @@ class FasterCacheConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class FirstBlockCacheConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class HookRegistry(metaclass=DummyObject):
_backends = ["torch"]
@@ -32,6 +167,21 @@ class HookRegistry(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class LayerSkipConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
_backends = ["torch"]
@@ -47,10 +197,33 @@ class PyramidAttentionBroadcastConfig(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class SmoothedEnergyGuidanceConfig(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
def apply_faster_cache(*args, **kwargs):
requires_backends(apply_faster_cache, ["torch"])
+def apply_first_block_cache(*args, **kwargs):
+ requires_backends(apply_first_block_cache, ["torch"])
+
+
+def apply_layer_skip(*args, **kwargs):
+ requires_backends(apply_layer_skip, ["torch"])
+
+
def apply_pyramid_attention_broadcast(*args, **kwargs):
requires_backends(apply_pyramid_attention_broadcast, ["torch"])
@@ -85,6 +258,21 @@ class AsymmetricAutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class AttentionBackendName(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class AuraFlowTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -925,6 +1113,21 @@ class SD3Transformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+class SkyReelsV2Transformer3DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class SparseControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1180,6 +1383,70 @@ class WanVACETransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
+def attention_backend(*args, **kwargs):
+ requires_backends(attention_backend, ["torch"])
+
+
+class ComponentsManager(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ComponentSpec(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ModularPipeline(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
+class ModularPipelineBlocks(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
def get_constant_schedule(*args, **kwargs):
requires_backends(get_constant_schedule, ["torch"])
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 656a8ac6c6..20382eafea 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -2,6 +2,96 @@
from ..utils import DummyObject, requires_backends
+class FluxAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class StableDiffusionXLModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanAutoBlocks(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WanModularPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class AllegroPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -692,6 +782,36 @@ class FluxInpaintPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class FluxKontextInpaintPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class FluxKontextPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class FluxPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
@@ -1757,6 +1877,81 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
+class SkyReelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class SkyReelsV2DiffusionForcingPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class SkyReelsV2DiffusionForcingVideoToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class SkyReelsV2ImageToVideoPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class SkyReelsV2Pipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class StableAudioPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py
index 4878937ab2..74ed240bf0 100644
--- a/src/diffusers/utils/dynamic_modules_utils.py
+++ b/src/diffusers/utils/dynamic_modules_utils.py
@@ -20,8 +20,11 @@ import json
import os
import re
import shutil
+import signal
import sys
+import threading
from pathlib import Path
+from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
+TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
+_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -154,33 +159,87 @@ def check_imports(filename):
return get_relative_imports(filename)
-def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
+def _raise_timeout_error(signum, frame):
+ raise ValueError(
+ "Loading this model requires you to execute custom code contained in the model repository on your local "
+ "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
+ )
+
+
+def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
+ if trust_remote_code is None:
+ if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
+ prev_sig_handler = None
+ try:
+ prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
+ signal.alarm(TIME_OUT_REMOTE_CODE)
+ while trust_remote_code is None:
+ answer = input(
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+ f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
+ f"Do you wish to run the custom code? [y/N] "
+ )
+ if answer.lower() in ["yes", "y", "1"]:
+ trust_remote_code = True
+ elif answer.lower() in ["no", "n", "0", ""]:
+ trust_remote_code = False
+ signal.alarm(0)
+ except Exception:
+ # OS which does not support signal.SIGALRM
+ raise ValueError(
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
+ )
+ finally:
+ if prev_sig_handler is not None:
+ signal.signal(signal.SIGALRM, prev_sig_handler)
+ signal.alarm(0)
+ elif has_remote_code:
+ # For the CI which puts the timeout at 0
+ _raise_timeout_error(None, None)
+
+ if has_remote_code and not trust_remote_code:
+ raise ValueError(
+ f"Loading {model_name} requires you to execute the configuration file in that"
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
+ " set the option `trust_remote_code=True` to remove this error."
+ )
+
+ return trust_remote_code
+
+
+def get_class_in_module(class_name, module_path, force_reload=False):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
- module_path = module_path.replace(os.path.sep, ".")
- try:
- module = importlib.import_module(module_path)
- except ModuleNotFoundError as e:
- # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
- # separator. We do a bit of monkey patching to detect and fix this case.
- if not (
- pretrained_model_name_or_path is not None
- and "." in pretrained_model_name_or_path
- and module_path.startswith("diffusers_modules")
- and pretrained_model_name_or_path.replace("/", "--") in module_path
- ):
- raise e # We can't figure this one out, just reraise the original error
+ name = os.path.normpath(module_path)
+ if name.endswith(".py"):
+ name = name[:-3]
+ name = name.replace(os.path.sep, ".")
+ module_file: Path = Path(HF_MODULES_CACHE) / module_path
- corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
- corrected_path = corrected_path.replace(
- pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
- pretrained_model_name_or_path.replace("/", "--"),
- )
- module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
+ with _HF_REMOTE_CODE_LOCK:
+ if force_reload:
+ sys.modules.pop(name, None)
+ importlib.invalidate_caches()
+ cached_module: Optional[ModuleType] = sys.modules.get(name)
+ module_spec = importlib.util.spec_from_file_location(name, location=module_file)
+
+ module: ModuleType
+ if cached_module is None:
+ module = importlib.util.module_from_spec(module_spec)
+ # insert it into sys.modules before any loading begins
+ sys.modules[name] = module
+ else:
+ module = cached_module
+
+ module_spec.loader.exec_module(module)
if class_name is None:
return find_pipeline_class(module)
+
return getattr(module, class_name)
@@ -259,8 +318,8 @@ def get_cached_module_file(
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+ You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -446,8 +505,8 @@ def get_class_from_dynamic_module(
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
+ You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
@@ -472,4 +531,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
- return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
+ return get_class_in_module(class_name, final_module)
diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py
index f80f96a342..cf85488b7a 100644
--- a/src/diffusers/utils/hub_utils.py
+++ b/src/diffusers/utils/hub_utils.py
@@ -304,8 +304,7 @@ def _get_model_file(
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
- "token having permission to this repo with `token` or log in with `huggingface-cli "
- "login`."
+ "token having permission to this repo with `token` or log in with `hf auth login`."
) from e
except RevisionNotFoundError as e:
raise EnvironmentError(
@@ -467,6 +466,7 @@ class PushToHubMixin:
token: Optional[str] = None,
commit_message: Optional[str] = None,
create_pr: bool = False,
+ subfolder: Optional[str] = None,
):
"""
Uploads all files in `working_dir` to `repo_id`.
@@ -481,7 +481,12 @@ class PushToHubMixin:
logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
return upload_folder(
- repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
+ repo_id=repo_id,
+ folder_path=working_dir,
+ token=token,
+ commit_message=commit_message,
+ create_pr=create_pr,
+ path_in_repo=subfolder,
)
def push_to_hub(
@@ -493,6 +498,7 @@ class PushToHubMixin:
create_pr: bool = False,
safe_serialization: bool = True,
variant: Optional[str] = None,
+ subfolder: Optional[str] = None,
) -> str:
"""
Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
@@ -508,8 +514,8 @@ class PushToHubMixin:
Whether to make the repo private. If `None` (default), the repo will be public unless the
organization's default is private. This value is ignored if the repo already exists.
token (`str`, *optional*):
- The token to use as HTTP bearer authorization for remote files. The token generated when running
- `huggingface-cli login` (stored in `~/.huggingface`).
+ The token to use as HTTP bearer authorization for remote files. The token generated when running `hf
+ auth login` (stored in `~/.huggingface`).
create_pr (`bool`, *optional*, defaults to `False`):
Whether or not to create a PR with the uploaded files or directly commit.
safe_serialization (`bool`, *optional*, defaults to `True`):
@@ -534,8 +540,9 @@ class PushToHubMixin:
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
# Create a new empty model card and eventually tag it
- model_card = load_or_create_model_card(repo_id, token=token)
- model_card = populate_model_card(model_card)
+ if not subfolder:
+ model_card = load_or_create_model_card(repo_id, token=token)
+ model_card = populate_model_card(model_card)
# Save all files.
save_kwargs = {"safe_serialization": safe_serialization}
@@ -546,7 +553,8 @@ class PushToHubMixin:
self.save_pretrained(tmpdir, **save_kwargs)
# Update model card if needed:
- model_card.save(os.path.join(tmpdir, "README.md"))
+ if not subfolder:
+ model_card.save(os.path.join(tmpdir, "README.md"))
return self._upload_folder(
tmpdir,
@@ -554,4 +562,5 @@ class PushToHubMixin:
token=token,
commit_message=commit_message,
create_pr=create_pr,
+ subfolder=subfolder,
)
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index f12e9de331..a27c2da648 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -220,6 +220,9 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
def is_torch_available():
@@ -378,6 +381,18 @@ def is_hpu_available():
return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+def is_sageattention_available():
+ return _sageattention_available
+
+
+def is_flash_attn_available():
+ return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+ return _flash_attn_3_available
+
+
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +819,51 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)
+def is_xformers_version(operation: str, version: str):
+ """
+ Compares the current xformers version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _xformers_available:
+ return False
+ return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+ """
+ Compares the current sageattention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _sageattention_available:
+ return False
+ return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+ """
+ Compares the current flash-attention version to a given reference with an operation.
+
+ Args:
+ operation (`str`):
+ A string representation of an operator, such as `">"` or `"<="`
+ version (`str`):
+ A version string
+ """
+ if not _flash_attn_available:
+ return False
+ return compare_versions(parse(_flash_attn_version), operation, version)
+
+
def get_objects_from_module(module):
"""
Returns a dict of object names and values in a module, while skipping private/internal objects
diff --git a/src/diffusers/utils/peft_utils.py b/src/diffusers/utils/peft_utils.py
index 3907bdd5b3..651fa27294 100644
--- a/src/diffusers/utils/peft_utils.py
+++ b/src/diffusers/utils/peft_utils.py
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
module.set_scale(adapter_name, 1.0)
-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
+def get_peft_kwargs(
+ rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
else:
lora_alpha = set(network_alpha_dict.values()).pop()
- # layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
# for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,21 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
"use_dora": use_dora,
"lora_bias": lora_bias,
}
+
+ # Example: try load FusionX LoRA into Wan VACE
+ exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
+ if exclude_modules:
+ if not is_peft_version(">=", "0.14.0"):
+ msg = """
+It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
+version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
+peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
+https://github.com/huggingface/diffusers/issues/new
+ """
+ logger.debug(msg)
+ else:
+ lora_config_kwargs.update({"exclude_modules": exclude_modules})
+
return lora_config_kwargs
@@ -294,11 +310,7 @@ def check_peft_version(min_version: str) -> None:
def _create_lora_config(
- state_dict,
- network_alphas,
- metadata,
- rank_pattern_dict,
- is_unet: bool = True,
+ state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
from peft import LoraConfig
@@ -306,7 +318,12 @@ def _create_lora_config(
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
- rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
+ rank_pattern_dict,
+ network_alpha_dict=network_alphas,
+ peft_state_dict=state_dict,
+ is_unet=is_unet,
+ model_state_dict=model_state_dict,
+ adapter_name=adapter_name,
)
_maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -371,3 +388,27 @@ def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
if warn_msg:
logger.warning(warn_msg)
+
+
+def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
+ """
+ Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
+ `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
+ doesn't exist in `peft_state_dict`.
+ """
+ if model_state_dict is None:
+ return
+ all_modules = set()
+ string_to_replace = f"{adapter_name}." if adapter_name else ""
+
+ for name in model_state_dict.keys():
+ if string_to_replace:
+ name = name.replace(string_to_replace, "")
+ if "." in name:
+ module_name = name.rsplit(".", 1)[0]
+ all_modules.add(module_name)
+
+ target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
+ exclude_modules = list(all_modules - target_modules_set)
+
+ return exclude_modules
diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py
index e5da39c1d8..3d9444975d 100644
--- a/src/diffusers/utils/testing_utils.py
+++ b/src/diffusers/utils/testing_utils.py
@@ -1,4 +1,5 @@
import functools
+import glob
import importlib
import importlib.metadata
import inspect
@@ -18,7 +19,7 @@ from collections import UserDict
from contextlib import contextmanager
from io import BytesIO, StringIO
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
import numpy as np
import PIL.Image
@@ -421,6 +422,10 @@ def require_big_accelerator(test_case):
Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
Flux, SD3, Cog, etc.
"""
+ import pytest
+
+ test_case = pytest.mark.big_accelerator(test_case)
+
if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case)
@@ -990,10 +995,10 @@ def pytest_terminal_summary_main(tr, id):
config.option.tbstyle = orig_tbstyle
-# Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
+# Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
"""
- To decorate flaky tests. They will be retried on failures.
+ To decorate flaky tests (methods or entire classes). They will be retried on failures.
Args:
max_attempts (`int`, *optional*, defaults to 5):
@@ -1005,22 +1010,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
etc.)
"""
- def decorator(test_func_ref):
- @functools.wraps(test_func_ref)
+ def decorator(obj):
+ # If decorating a class, wrap each test method on it
+ if inspect.isclass(obj):
+ for attr_name, attr_value in list(obj.__dict__.items()):
+ if callable(attr_value) and attr_name.startswith("test"):
+ # recursively decorate the method
+ setattr(obj, attr_name, decorator(attr_value))
+ return obj
+
+ # Otherwise we're decorating a single test function / method
+ @functools.wraps(obj)
def wrapper(*args, **kwargs):
retry_count = 1
-
while retry_count < max_attempts:
try:
- return test_func_ref(*args, **kwargs)
-
+ return obj(*args, **kwargs)
except Exception as err:
- print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+ msg = (
+ f"[FLAKY] {description or obj.__name__!r} "
+ f"failed on attempt {retry_count}/{max_attempts}: {err}"
+ )
+ print(msg, file=sys.stderr)
if wait_before_retry is not None:
time.sleep(wait_before_retry)
retry_count += 1
- return test_func_ref(*args, **kwargs)
+ return obj(*args, **kwargs)
return wrapper
@@ -1377,6 +1393,103 @@ if TYPE_CHECKING:
else:
DevicePropertiesUserDict = UserDict
+if is_torch_available():
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.group_offloading import (
+ _GROUP_ID_LAZY_LEAF,
+ _compute_group_hash,
+ _find_parent_module_in_module_dict,
+ _gather_buffers_with_no_group_offloading_parent,
+ _gather_parameters_with_no_group_offloading_parent,
+ )
+
+ def _get_expected_safetensors_files(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ ) -> Set[str]:
+ expected_files = set()
+
+ def get_hashed_filename(group_id: str) -> str:
+ short_hash = _compute_group_hash(group_id)
+ return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
+
+ if offload_type == "block_level":
+ if num_blocks_per_group is None:
+ raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
+
+ # Handle groups of ModuleList and Sequential blocks
+ unmatched_modules = []
+ for name, submodule in module.named_children():
+ if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
+ unmatched_modules.append(module)
+ continue
+
+ for i in range(0, len(submodule), num_blocks_per_group):
+ current_modules = submodule[i : i + num_blocks_per_group]
+ if not current_modules:
+ continue
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+ expected_files.add(get_hashed_filename(group_id))
+
+ # Handle the group for unmatched top-level modules and parameters
+ for module in unmatched_modules:
+ expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
+
+ elif offload_type == "leaf_level":
+ # Handle leaf-level module groups
+ for name, submodule in module.named_modules():
+ if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
+ # These groups will always have parameters, so a file is expected
+ expected_files.add(get_hashed_filename(name))
+
+ # Handle groups for non-leaf parameters/buffers
+ modules_with_group_offloading = {
+ name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
+ }
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+ all_orphans = parameters + buffers
+ if all_orphans:
+ parent_to_tensors = {}
+ module_dict = dict(module.named_modules())
+ for tensor_name, _ in all_orphans:
+ parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
+ if parent_name not in parent_to_tensors:
+ parent_to_tensors[parent_name] = []
+ parent_to_tensors[parent_name].append(tensor_name)
+
+ for parent_name in parent_to_tensors:
+ # A file is expected for each parent that gathers orphaned tensors
+ expected_files.add(get_hashed_filename(parent_name))
+ expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
+
+ else:
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
+
+ return expected_files
+
+ def _check_safetensors_serialization(
+ module: torch.nn.Module,
+ offload_to_disk_path: str,
+ offload_type: str,
+ num_blocks_per_group: Optional[int] = None,
+ ) -> bool:
+ if not os.path.isdir(offload_to_disk_path):
+ return False, None, None
+
+ expected_files = _get_expected_safetensors_files(
+ module, offload_to_disk_path, offload_type, num_blocks_per_group
+ )
+ actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
+ missing_files = expected_files - actual_files
+ extra_files = actual_files - expected_files
+
+ is_correct = not missing_files and not extra_files
+ return is_correct, extra_files, missing_files
+
class Expectations(DevicePropertiesUserDict):
def get_expectation(self) -> Any:
diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py
index ffc1119727..dd54cb2b91 100644
--- a/src/diffusers/utils/torch_utils.py
+++ b/src/diffusers/utils/torch_utils.py
@@ -92,6 +92,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
+def unwrap_module(module):
+ """Unwraps a module if it was compiled with torch.compile()"""
+ return module._orig_mod if is_compiled_module(module) else module
+
+
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
@@ -170,6 +175,8 @@ def get_device():
return "npu"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"
+ elif torch.backends.mps.is_available():
+ return "mps"
else:
return "cpu"
@@ -177,5 +184,14 @@ def get_device():
def empty_device_cache(device_type: Optional[str] = None):
if device_type is None:
device_type = get_device()
+ if device_type in ["cpu"]:
+ return
device_mod = getattr(torch, device_type, torch.cuda)
device_mod.empty_cache()
+
+
+def device_synchronize(device_type: Optional[str] = None):
+ if device_type is None:
+ device_type = get_device()
+ device_mod = getattr(torch, device_type, torch.cuda)
+ device_mod.synchronize()
diff --git a/tests/conftest.py b/tests/conftest.py
index 7e9c4e8f39..3237fb9c7b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,10 @@ sys.path.insert(1, git_repo_path)
warnings.simplefilter(action="ignore", category=FutureWarning)
+def pytest_configure(config):
+ config.addinivalue_line("markers", "big_accelerator: marks tests as requiring big accelerator resources")
+
+
def pytest_addoption(parser):
from diffusers.utils.testing_utils import pytest_addoption_shared
diff --git a/tests/lora/test_deprecated_utilities.py b/tests/lora/test_deprecated_utilities.py
deleted file mode 100644
index 4275ef8089..0000000000
--- a/tests/lora/test_deprecated_utilities.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import os
-import tempfile
-import unittest
-
-import torch
-
-from diffusers.loaders.lora_base import LoraBaseMixin
-
-
-class UtilityMethodDeprecationTests(unittest.TestCase):
- def test_fetch_state_dict_cls_method_raises_warning(self):
- state_dict = torch.nn.Linear(3, 3).state_dict()
- with self.assertWarns(FutureWarning) as warning:
- _ = LoraBaseMixin._fetch_state_dict(
- state_dict,
- weight_name=None,
- use_safetensors=False,
- local_files_only=True,
- cache_dir=None,
- force_download=False,
- proxies=None,
- token=None,
- revision=None,
- subfolder=None,
- user_agent=None,
- allow_pickle=None,
- )
- warning_message = str(warning.warnings[0].message)
- assert "Using the `_fetch_state_dict()` method from" in warning_message
-
- def test_best_guess_weight_name_cls_method_raises_warning(self):
- with tempfile.TemporaryDirectory() as tmpdir:
- state_dict = torch.nn.Linear(3, 3).state_dict()
- torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
-
- with self.assertWarns(FutureWarning) as warning:
- _ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
- warning_message = str(warning.warnings[0].message)
- assert "Using the `_best_guess_weight_name()` method from" in warning_message
diff --git a/tests/lora/test_lora_layers_cogvideox.py b/tests/lora/test_lora_layers_cogvideox.py
index bd7b33445c..565d6db697 100644
--- a/tests/lora/test_lora_layers_cogvideox.py
+++ b/tests/lora/test_lora_layers_cogvideox.py
@@ -16,6 +16,7 @@ import sys
import unittest
import torch
+from parameterized import parameterized
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
@@ -28,6 +29,7 @@ from diffusers import (
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
+ require_torch_accelerator,
)
@@ -127,6 +129,13 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
def test_lora_scale_kwargs_match_fusion(self):
super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+ @parameterized.expand([("block_level", True), ("leaf_level", False)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+ super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
@unittest.skip("Not supported in CogVideoX.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
diff --git a/tests/lora/test_lora_layers_cogview4.py b/tests/lora/test_lora_layers_cogview4.py
index 23573bcb21..b7367d9b09 100644
--- a/tests/lora/test_lora_layers_cogview4.py
+++ b/tests/lora/test_lora_layers_cogview4.py
@@ -18,10 +18,17 @@ import unittest
import numpy as np
import torch
+from parameterized import parameterized
from transformers import AutoTokenizer, GlmModel
from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps, torch_device
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ require_torch_accelerator,
+ skip_mps,
+ torch_device,
+)
sys.path.append(".")
@@ -141,6 +148,13 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"Loading from saved checkpoints should give same results.",
)
+ @parameterized.expand([("block_level", True), ("leaf_level", False)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ # TODO: We don't run the (leaf_level, True) test here that is enabled for other models.
+ # The reason for this can be found here: https://github.com/huggingface/diffusers/pull/11804#issuecomment-3013325338
+ super()._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
@unittest.skip("Not supported in CogView4.")
def test_simple_inference_with_text_denoiser_block_scale(self):
pass
diff --git a/tests/lora/test_lora_layers_flux.py b/tests/lora/test_lora_layers_flux.py
index 336ac2246f..95f1e137e9 100644
--- a/tests/lora/test_lora_layers_flux.py
+++ b/tests/lora/test_lora_layers_flux.py
@@ -20,7 +20,6 @@ import tempfile
import unittest
import numpy as np
-import pytest
import safetensors.torch
import torch
from parameterized import parameterized
@@ -813,7 +812,6 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on audace.
@@ -960,7 +958,6 @@ class FluxLoRAIntegrationTests(unittest.TestCase):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlLoRAIntegrationTests(unittest.TestCase):
num_inference_steps = 10
seed = 0
diff --git a/tests/lora/test_lora_layers_hunyuanvideo.py b/tests/lora/test_lora_layers_hunyuanvideo.py
index 19e31f320d..4cbd6523e7 100644
--- a/tests/lora/test_lora_layers_hunyuanvideo.py
+++ b/tests/lora/test_lora_layers_hunyuanvideo.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import CLIPTextModel, CLIPTokenizer, LlamaModel, LlamaTokenizerFast
@@ -198,7 +197,6 @@ class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class HunyuanVideoLoRAIntegrationTests(unittest.TestCase):
"""internal note: The integration slices were obtained on DGX.
diff --git a/tests/lora/test_lora_layers_sd.py b/tests/lora/test_lora_layers_sd.py
index a81128fa44..1c5a9b00e9 100644
--- a/tests/lora/test_lora_layers_sd.py
+++ b/tests/lora/test_lora_layers_sd.py
@@ -120,7 +120,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
# We will offload the first adapter in CPU and check if the offloading
@@ -187,7 +187,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(
check_if_lora_correctly_set(pipe.unet),
- "Lora not correctly set in text encoder",
+ "Lora not correctly set in unet",
)
for name, param in pipe.unet.named_parameters():
@@ -208,6 +208,53 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
if "lora_" in name:
self.assertNotEqual(param.device, torch.device("cpu"))
+ @slow
+ @require_torch_accelerator
+ def test_integration_set_lora_device_different_target_layers(self):
+ # fixes a bug that occurred when calling set_lora_device with multiple adapters loaded that target different
+ # layers, see #11833
+ from peft import LoraConfig
+
+ path = "stable-diffusion-v1-5/stable-diffusion-v1-5"
+ pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float16)
+ # configs partly target the same, partly different layers
+ config0 = LoraConfig(target_modules=["to_k", "to_v"])
+ config1 = LoraConfig(target_modules=["to_k", "to_q"])
+ pipe.unet.add_adapter(config0, adapter_name="adapter-0")
+ pipe.unet.add_adapter(config1, adapter_name="adapter-1")
+ pipe = pipe.to(torch_device)
+
+ self.assertTrue(
+ check_if_lora_correctly_set(pipe.unet),
+ "Lora not correctly set in unet",
+ )
+
+ # sanity check that the adapters don't target the same layers, otherwise the test passes even without the fix
+ modules_adapter_0 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-0")}
+ modules_adapter_1 = {n for n, _ in pipe.unet.named_modules() if n.endswith(".adapter-1")}
+ self.assertNotEqual(modules_adapter_0, modules_adapter_1)
+ self.assertTrue(modules_adapter_0 - modules_adapter_1)
+ self.assertTrue(modules_adapter_1 - modules_adapter_0)
+
+ # setting both separately works
+ pipe.set_lora_device(["adapter-0"], "cpu")
+ pipe.set_lora_device(["adapter-1"], "cpu")
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device == torch.device("cpu"))
+
+ # setting both at once also works
+ pipe.set_lora_device(["adapter-0", "adapter-1"], torch_device)
+
+ for name, module in pipe.unet.named_modules():
+ if "adapter-0" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+ elif "adapter-1" in name and not isinstance(module, (nn.Dropout, nn.Identity)):
+ self.assertTrue(module.weight.device != torch.device("cpu"))
+
@slow
@nightly
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8a8f2a676d..8928ccbac2 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -17,7 +17,6 @@ import sys
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -139,7 +138,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_torch_accelerator
@require_peft_backend
@require_big_accelerator
-@pytest.mark.big_accelerator
class SD3LoraIntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/lora/test_lora_layers_wan.py b/tests/lora/test_lora_layers_wan.py
index 95ec44b2bf..fe26a56e77 100644
--- a/tests/lora/test_lora_layers_wan.py
+++ b/tests/lora/test_lora_layers_wan.py
@@ -24,7 +24,11 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
-from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ require_peft_backend,
+ skip_mps,
+)
sys.path.append(".")
diff --git a/tests/lora/test_lora_layers_wanvace.py b/tests/lora/test_lora_layers_wanvace.py
new file mode 100644
index 0000000000..f976577653
--- /dev/null
+++ b/tests/lora/test_lora_layers_wanvace.py
@@ -0,0 +1,222 @@
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import tempfile
+import unittest
+
+import numpy as np
+import pytest
+import safetensors.torch
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, WanVACEPipeline, WanVACETransformer3DModel
+from diffusers.utils.import_utils import is_peft_available
+from diffusers.utils.testing_utils import (
+ floats_tensor,
+ is_flaky,
+ require_peft_backend,
+ require_peft_version_greater,
+ skip_mps,
+ torch_device,
+)
+
+
+if is_peft_available():
+ from peft.utils import get_peft_model_state_dict
+
+sys.path.append(".")
+
+from utils import PeftLoraLoaderMixinTests # noqa: E402
+
+
+@require_peft_backend
+@skip_mps
+@is_flaky(max_attempts=10, description="very flaky class")
+class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+ pipeline_class = WanVACEPipeline
+ scheduler_cls = FlowMatchEulerDiscreteScheduler
+ scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+ scheduler_kwargs = {}
+
+ transformer_kwargs = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 8,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 32,
+ "freq_dim": 16,
+ "ffn_dim": 16,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 16,
+ "vace_layers": [0],
+ "vace_in_channels": 72,
+ }
+ transformer_cls = WanVACETransformer3DModel
+ vae_kwargs = {
+ "base_dim": 3,
+ "z_dim": 4,
+ "dim_mult": [1, 1, 1, 1],
+ "latents_mean": torch.randn(4).numpy().tolist(),
+ "latents_std": torch.randn(4).numpy().tolist(),
+ "num_res_blocks": 1,
+ "temperal_downsample": [False, True, True],
+ }
+ vae_cls = AutoencoderKLWan
+ has_two_text_encoders = True
+ tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
+ text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
+
+ text_encoder_target_modules = ["q", "k", "v", "o"]
+
+ @property
+ def output_shape(self):
+ return (1, 9, 16, 16, 3)
+
+ def get_dummy_inputs(self, with_generator=True):
+ batch_size = 1
+ sequence_length = 16
+ num_channels = 4
+ num_frames = 9
+ num_latent_frames = 3 # (num_frames - 1) // temporal_compression_ratio + 1
+ sizes = (4, 4)
+ height, width = 16, 16
+
+ generator = torch.manual_seed(0)
+ noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
+ input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
+ video = [Image.new("RGB", (height, width))] * num_frames
+ mask = [Image.new("L", (height, width), 0)] * num_frames
+
+ pipeline_inputs = {
+ "video": video,
+ "mask": mask,
+ "prompt": "",
+ "num_frames": num_frames,
+ "num_inference_steps": 1,
+ "guidance_scale": 6.0,
+ "height": height,
+ "width": height,
+ "max_sequence_length": sequence_length,
+ "output_type": "np",
+ }
+ if with_generator:
+ pipeline_inputs.update({"generator": generator})
+
+ return noise, input_ids, pipeline_inputs
+
+ def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+ super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)
+
+ def test_simple_inference_with_text_denoiser_lora_unfused(self):
+ super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+ pass
+
+ @unittest.skip("Not supported in Wan VACE.")
+ def test_modify_padding_mode(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_partial_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_and_scale(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_fused(self):
+ pass
+
+ @unittest.skip("Text encoder LoRA is not supported in Wan VACE.")
+ def test_simple_inference_with_text_lora_save_load(self):
+ pass
+
+ @pytest.mark.xfail(
+ condition=True,
+ reason="RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same",
+ strict=True,
+ )
+ def test_layerwise_casting_inference_denoiser(self):
+ super().test_layerwise_casting_inference_denoiser()
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules_wanvace(self):
+ scheduler_cls = self.scheduler_classes[0]
+ exclude_module_name = "vace_blocks.0.proj_out"
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # only supported for `denoiser` now
+ denoiser_lora_config.target_modules = ["proj_out"]
+ denoiser_lora_config.exclude_modules = [exclude_module_name]
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ # The state dict shouldn't contain the modules to be excluded from LoRA.
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts)
+ pipe.unload_lora_weights()
+
+ # Check in the loaded state dict.
+ loaded_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+ self.assertTrue(not any(exclude_module_name in k for k in loaded_state_dict))
+ self.assertTrue(any("proj_out" in k for k in loaded_state_dict))
+
+ # Check in the state dict obtained after loading LoRA.
+ pipe.load_lora_weights(tmpdir)
+ state_dict_from_model = get_peft_model_state_dict(pipe.transformer, adapter_name="default_0")
+ self.assertTrue(not any(exclude_module_name in k for k in state_dict_from_model))
+ self.assertTrue(any("proj_out" in k for k in state_dict_from_model))
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora outputs should match.",
+ )
+
+ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+ super().test_simple_inference_with_text_denoiser_lora_and_scale()
diff --git a/tests/lora/utils.py b/tests/lora/utils.py
index 93dc4a2c37..9edaeafc71 100644
--- a/tests/lora/utils.py
+++ b/tests/lora/utils.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import os
import re
@@ -39,6 +40,7 @@ from diffusers.utils.testing_utils import (
is_torch_version,
require_peft_backend,
require_peft_version_greater,
+ require_torch_accelerator,
require_transformers_version_greater,
skip_mps,
torch_device,
@@ -290,9 +292,21 @@ class PeftLoraLoaderMixinTests:
return modules_to_save
- def check_if_adapters_added_correctly(
- self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
- ):
+ def _get_exclude_modules(self, pipe):
+ from diffusers.utils.peft_utils import _derive_exclude_modules
+
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+ modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+ denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+ pipe.unload_lora_weights()
+ denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+ exclude_modules = _derive_exclude_modules(
+ denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+ )
+ return exclude_modules
+
+ def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -344,7 +358,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -427,7 +441,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -483,7 +497,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -521,7 +535,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
@@ -553,7 +567,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -588,7 +602,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -639,7 +653,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -690,7 +704,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
@@ -733,7 +747,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -774,7 +788,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
@@ -818,7 +832,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -856,7 +870,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
@@ -892,7 +906,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1009,7 +1023,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1031,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
)
@@ -1758,7 +1772,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1849,7 +1863,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1936,7 +1950,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2095,14 +2109,15 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
def test_layerwise_casting_inference_denoiser(self):
- from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
- if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if "lora" in name or any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2118,7 +2133,7 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
- pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+ pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2153,10 +2168,10 @@ class PeftLoraLoaderMixinTests:
See the docstring of [`hooks.layerwise_casting.PeftInputAutocastDisableHook`] for more details.
"""
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from diffusers.hooks.layerwise_casting import (
_PEFT_AUTOCAST_DISABLE_HOOK,
DEFAULT_SKIP_MODULES_PATTERN,
- SUPPORTED_PYTORCH_LAYERS,
apply_layerwise_casting,
)
@@ -2166,7 +2181,7 @@ class PeftLoraLoaderMixinTests:
def check_module(denoiser):
# This will also check if the peft layers are in torch.float8_e4m3fn dtype (unlike test_layerwise_casting_inference_denoiser)
for name, module in denoiser.named_modules():
- if not isinstance(module, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(module, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -2236,7 +2251,7 @@ class PeftLoraLoaderMixinTests:
)
pipe = self.pipeline_class(**components)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
@@ -2289,7 +2304,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
- pipe, _ = self.check_if_adapters_added_correctly(
+ pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2308,6 +2323,77 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
+ def test_lora_unload_add_adapter(self):
+ """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # unload and then add.
+ pipe.unload_lora_weights()
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ @require_peft_version_greater("0.13.2")
+ def test_lora_exclude_modules(self):
+ """
+ Test to check if `exclude_modules` works or not. It works in the following way:
+ we first create a pipeline and insert LoRA config into it. We then derive a `set`
+ of modules to exclude by investigating its denoiser state dict and denoiser LoRA
+ state dict.
+
+ We then create a new LoRA config to include the `exclude_modules` and perform tests.
+ """
+ scheduler_cls = self.scheduler_classes[0]
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+ pipe = self.pipeline_class(**components).to(torch_device)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(output_no_lora.shape == self.output_shape)
+
+ # only supported for `denoiser` now
+ pipe_cp = copy.deepcopy(pipe)
+ pipe_cp, _ = self.add_adapters_to_pipeline(
+ pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
+ pipe_cp.to("cpu")
+ del pipe_cp
+
+ denoiser_lora_config.exclude_modules = denoiser_exclude_modules
+ pipe, _ = self.add_adapters_to_pipeline(
+ pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+ )
+ output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+ self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+ pipe.unload_lora_weights()
+ pipe.load_lora_weights(tmpdir)
+
+ output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(
+ not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
+ "LoRA should change outputs.",
+ )
+ self.assertTrue(
+ np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
+ "Lora outputs should match.",
+ )
+
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
@@ -2355,3 +2441,104 @@ class PeftLoraLoaderMixinTests:
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
+
+ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
+
+ onload_device = torch_device
+ offload_device = torch.device("cpu")
+
+ components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+ self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+
+ components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ check_if_lora_correctly_set(denoiser)
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+ # Test group offloading with load_lora_weights
+ denoiser.enable_group_offload(
+ onload_device=onload_device,
+ offload_device=offload_device,
+ offload_type=offload_type,
+ num_blocks_per_group=1,
+ use_stream=use_stream,
+ )
+ group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_1 is not None)
+ output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ # Test group offloading after removing the lora
+ pipe.unload_lora_weights()
+ group_offload_hook_2 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_2 is not None)
+ output_2 = pipe(**inputs, generator=torch.manual_seed(0))[0] # noqa: F841
+
+ # Add the lora again and check if group offloading works
+ pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
+ check_if_lora_correctly_set(denoiser)
+ group_offload_hook_3 = _get_top_level_group_offload_hook(denoiser)
+ self.assertTrue(group_offload_hook_3 is not None)
+ output_3 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ self.assertTrue(np.allclose(output_1, output_3, atol=1e-3, rtol=1e-3))
+
+ @parameterized.expand([("block_level", True), ("leaf_level", False), ("leaf_level", True)])
+ @require_torch_accelerator
+ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
+ for cls in inspect.getmro(self.__class__):
+ if "test_group_offloading_inference_denoiser" in cls.__dict__ and cls is not PeftLoraLoaderMixinTests:
+ # Skip this test if it is overwritten by child class. We need to do this because parameterized
+ # materializes the test methods on invocation which cannot be overridden.
+ return
+ self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+ @require_torch_accelerator
+ def test_lora_loading_model_cpu_offload(self):
+ components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ _, _, inputs = self.get_dummy_inputs(with_generator=False)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ denoiser.add_adapter(denoiser_lora_config)
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+ lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+ self.pipeline_class.save_lora_weights(
+ save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
+ )
+ # reinitialize the pipeline to mimic the inference workflow.
+ components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
+ pipe = self.pipeline_class(**components)
+ pipe.enable_model_cpu_offload(device=torch_device)
+ pipe.load_lora_weights(tmpdirname)
+ denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+ self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+ output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
+ self.assertTrue(np.allclose(output_lora, output_lora_loaded, atol=1e-3, rtol=1e-3))
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index e8b41ddbfd..36b563ba9f 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -61,6 +61,7 @@ from diffusers.utils import (
from diffusers.utils.hub_utils import _add_variant
from diffusers.utils.testing_utils import (
CaptureLogger,
+ _check_safetensors_serialization,
backend_empty_cache,
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
@@ -74,8 +75,8 @@ from diffusers.utils.testing_utils import (
require_torch_2,
require_torch_accelerator,
require_torch_accelerator_with_training,
- require_torch_gpu,
require_torch_multi_accelerator,
+ require_torch_version_greater,
run_test_in_subprocess,
slow,
torch_all_close,
@@ -1349,7 +1350,6 @@ class ModelTesterMixin:
new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
# Making sure part of the model will actually end up offloaded
self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
- print(f" new_model.hf_device_map:{new_model.hf_device_map}")
self.check_device_map_is_respected(new_model, new_model.hf_device_map)
@@ -1528,21 +1528,24 @@ class ModelTesterMixin:
test_fn(torch.float8_e5m2, torch.float32)
test_fn(torch.float8_e4m3fn, torch.bfloat16)
+ @torch.no_grad()
def test_layerwise_casting_inference(self):
- from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
+ from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN
torch.manual_seed(0)
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
- model = self.model_class(**config).eval()
- model = model.to(torch_device)
- base_slice = model(**inputs_dict)[0].flatten().detach().cpu().numpy()
+ model = self.model_class(**config)
+ model.eval()
+ model.to(torch_device)
+ base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
def check_linear_dtype(module, storage_dtype, compute_dtype):
patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
if getattr(module, "_skip_layerwise_casting_patterns", None) is not None:
patterns_to_check += tuple(module._skip_layerwise_casting_patterns)
for name, submodule in module.named_modules():
- if not isinstance(submodule, SUPPORTED_PYTORCH_LAYERS):
+ if not isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
continue
dtype_to_check = storage_dtype
if any(re.search(pattern, name) for pattern in patterns_to_check):
@@ -1573,6 +1576,7 @@ class ModelTesterMixin:
test_layerwise_casting(torch.float8_e4m3fn, torch.bfloat16)
@require_torch_accelerator
+ @torch.no_grad()
def test_layerwise_casting_memory(self):
MB_TOLERANCE = 0.2
LEAST_COMPUTE_CAPABILITY = 8.0
@@ -1699,22 +1703,43 @@ class ModelTesterMixin:
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
_ = model(**inputs_dict)[0]
- @parameterized.expand([(False, "block_level"), (True, "leaf_level")])
+ @parameterized.expand([("block_level", False), ("leaf_level", True)])
@require_torch_accelerator
@torch.no_grad()
- def test_group_offloading_with_disk(self, record_stream, offload_type):
+ @torch.inference_mode()
+ def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
if not self.model_class._supports_group_offloading:
pytest.skip("Model does not support group offloading.")
- torch.manual_seed(0)
+ def _has_generator_arg(model):
+ sig = inspect.signature(model.forward)
+ params = sig.parameters
+ return "generator" in params
+
+ def _run_forward(model, inputs_dict):
+ accepts_generator = _has_generator_arg(model)
+ if accepts_generator:
+ inputs_dict["generator"] = torch.manual_seed(0)
+ torch.manual_seed(0)
+ return model(**inputs_dict)[0]
+
+ if self.__class__.__name__ == "AutoencoderKLCosmosTests" and offload_type == "leaf_level":
+ pytest.skip("With `leaf_type` as the offloading type, it fails. Needs investigation.")
+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+ torch.manual_seed(0)
model = self.model_class(**init_dict)
+ model.eval()
+ model.to(torch_device)
+ output_without_group_offloading = _run_forward(model, inputs_dict)
+
torch.manual_seed(0)
- init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.eval()
- additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
+
+ num_blocks_per_group = None if offload_type == "leaf_level" else 1
+ additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
with tempfile.TemporaryDirectory() as tmpdir:
model.enable_group_offload(
torch_device,
@@ -1725,8 +1750,25 @@ class ModelTesterMixin:
**additional_kwargs,
)
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
- assert has_safetensors, "No safetensors found in the directory."
- _ = model(**inputs_dict)[0]
+ self.assertTrue(has_safetensors, "No safetensors found in the directory.")
+
+ # For "leaf-level", there is a prefetching hook which makes this check a bit non-deterministic
+ # in nature. So, skip it.
+ if offload_type != "leaf_level":
+ is_correct, extra_files, missing_files = _check_safetensors_serialization(
+ module=model,
+ offload_to_disk_path=tmpdir,
+ offload_type=offload_type,
+ num_blocks_per_group=num_blocks_per_group,
+ )
+ if not is_correct:
+ if extra_files:
+ raise ValueError(f"Found extra files: {', '.join(extra_files)}")
+ elif missing_files:
+ raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
+
+ output_with_group_offloading = _run_forward(model, inputs_dict)
+ self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))
def test_auto_model(self, expected_max_diff=5e-5):
if self.forward_requires_fresh_args:
@@ -1787,8 +1829,8 @@ class ModelTesterMixin:
assert msg_substring in str(err_ctx.exception)
- @parameterized.expand([0, "cuda", torch.device("cuda")])
- @require_torch_gpu
+ @parameterized.expand([0, torch_device, torch.device(torch_device)])
+ @require_torch_accelerator
def test_passing_non_dict_device_map_works(self, device_map):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).eval()
@@ -1797,8 +1839,8 @@ class ModelTesterMixin:
loaded_model = self.model_class.from_pretrained(tmpdir, device_map=device_map)
_ = loaded_model(**inputs_dict)
- @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
- @require_torch_gpu
+ @parameterized.expand([("", torch_device), ("", torch.device(torch_device))])
+ @require_torch_accelerator
def test_passing_dict_device_map_works(self, name, device):
# There are other valid dict-based `device_map` values too. It's best to refer to
# the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
@@ -1903,11 +1945,14 @@ class ModelPushToHubTester(unittest.TestCase):
delete_repo(self.repo_id, token=TOKEN)
-@require_torch_gpu
+@require_torch_accelerator
@require_torch_2
@is_torch_compile
@slow
+@require_torch_version_greater("2.7.1")
class TorchCompileTesterMixin:
+ different_shapes_for_compilation = None
+
def setUp(self):
# clean up the VRAM before each test
super().setUp()
@@ -1936,19 +1981,40 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
_ = model(**inputs_dict)
+ def test_torch_compile_repeated_blocks(self):
+ if self.model_class._repeated_blocks is None:
+ pytest.skip("Skipping test as the model class doesn't have `_repeated_blocks` set.")
+
+ init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+ model = self.model_class(**init_dict).to(torch_device)
+ model.compile_repeated_blocks(fullgraph=True)
+
+ recompile_limit = 1
+ if self.model_class.__name__ == "UNet2DConditionModel":
+ recompile_limit = 2
+
+ with (
+ torch._inductor.utils.fresh_inductor_cache(),
+ torch._dynamo.config.patch(recompile_limit=recompile_limit),
+ torch.no_grad(),
+ ):
+ _ = model(**inputs_dict)
+ _ = model(**inputs_dict)
+
def test_compile_with_group_offloading(self):
+ if not self.model_class._supports_group_offloading:
+ pytest.skip("Model does not support group offloading.")
+
torch._dynamo.config.cache_size_limit = 10000
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
- if not getattr(model, "_supports_group_offloading", True):
- return
-
model.eval()
# TODO: Can test for other group offloading kwargs later if needed.
group_offload_kwargs = {
- "onload_device": "cuda",
+ "onload_device": torch_device,
"offload_device": "cpu",
"offload_type": "block_level",
"num_blocks_per_group": 1,
@@ -1961,12 +2027,28 @@ class TorchCompileTesterMixin:
_ = model(**inputs_dict)
_ = model(**inputs_dict)
+ @require_torch_version_greater("2.7.1")
+ def test_compile_on_different_shapes(self):
+ if self.different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ torch.fx.experimental._config.use_duck_shape = False
+
+ init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+ model = self.model_class(**init_dict).to(torch_device)
+ model = torch.compile(model, fullgraph=True, dynamic=True)
+
+ for height, width in self.different_shapes_for_compilation:
+ with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
+ inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**inputs_dict)
+
@slow
@require_torch_2
@require_torch_accelerator
@require_peft_backend
@require_peft_version_greater("0.14.0")
+@require_torch_version_greater("2.7.1")
@is_torch_compile
class LoraHotSwappingForModelTesterMixin:
"""Test that hotswapping does not result in recompilation on the model directly.
@@ -1981,6 +2063,8 @@ class LoraHotSwappingForModelTesterMixin:
"""
+ different_shapes_for_compilation = None
+
def tearDown(self):
# It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model,
# there will be recompilation errors, as torch caches the model when run in the same process.
@@ -2018,11 +2102,13 @@ class LoraHotSwappingForModelTesterMixin:
- hotswap the second adapter
- check that the outputs are correct
- optionally compile the model
+ - optionally check if recompilations happen on different shapes
Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would
fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is
fine.
"""
+ different_shapes = self.different_shapes_for_compilation
# create 2 adapters with different ranks and alphas
torch.manual_seed(0)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -2072,19 +2158,30 @@ class LoraHotSwappingForModelTesterMixin:
model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None)
if do_compile:
- model = torch.compile(model, mode="reduce-overhead")
+ model = torch.compile(model, mode="reduce-overhead", dynamic=different_shapes is not None)
with torch.inference_mode():
- output0_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
+ # additionally check if dynamic compilation works.
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output0_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)
# hotswap the 2nd adapter
model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None)
# we need to call forward to potentially trigger recompilation
with torch.inference_mode():
- output1_after = model(**inputs_dict)["sample"]
- assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
+ if different_shapes is not None:
+ for height, width in different_shapes:
+ new_inputs_dict = self.prepare_dummy_input(height=height, width=width)
+ _ = model(**new_inputs_dict)
+ else:
+ output1_after = model(**inputs_dict)["sample"]
+ assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
# check error when not passing valid adapter name
name = "does-not-exist"
@@ -2108,7 +2205,7 @@ class LoraHotSwappingForModelTesterMixin:
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
- return
+ pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["conv", "conv1", "conv2"]
@@ -2118,7 +2215,7 @@ class LoraHotSwappingForModelTesterMixin:
@parameterized.expand([(11, 11), (7, 13), (13, 7)]) # important to test small to large and vice versa
def test_hotswapping_compiled_model_both_linear_and_conv2d(self, rank0, rank1):
if "unet" not in self.model_class.__name__.lower():
- return
+ pytest.skip("Test only applies to UNet.")
# It's important to add this context to raise an error on recompilation
target_modules = ["to_q", "conv"]
@@ -2202,3 +2299,23 @@ class LoraHotSwappingForModelTesterMixin:
do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1
)
assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output)
+
+ @parameterized.expand([(11, 11), (7, 13), (13, 7)])
+ @require_torch_version_greater("2.7.1")
+ def test_hotswapping_compile_on_different_shapes(self, rank0, rank1):
+ different_shapes_for_compilation = self.different_shapes_for_compilation
+ if different_shapes_for_compilation is None:
+ pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.")
+ # Specifying `use_duck_shape=False` instructs the compiler if it should use the same symbolic
+ # variable to represent input sizes that are the same. For more details,
+ # check out this [comment](https://github.com/huggingface/diffusers/pull/11327#discussion_r2047659790).
+ torch.fx.experimental._config.use_duck_shape = False
+
+ target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
+ with torch._dynamo.config.patch(error_on_recompile=True):
+ self.check_model_hotswap(
+ do_compile=True,
+ rank0=rank0,
+ rank1=rank1,
+ target_modules0=target_modules,
+ )
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 0a55236ef1..68b5c02bc0 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -91,10 +91,20 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
@property
def dummy_input(self):
+ return self.prepare_dummy_input()
+
+ @property
+ def input_shape(self):
+ return (16, 4)
+
+ @property
+ def output_shape(self):
+ return (16, 4)
+
+ def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
num_latent_channels = 4
num_image_channels = 3
- height = width = 4
sequence_length = 48
embedding_dim = 32
@@ -114,14 +124,6 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
"timestep": timestep,
}
- @property
- def input_shape(self):
- return (16, 4)
-
- @property
- def output_shape(self):
- return (16, 4)
-
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 1,
@@ -173,13 +175,21 @@ class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
class FluxTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
+
class FluxTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
+ different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return FluxTransformerTests().prepare_init_args_and_inputs_for_common()
+
+ def prepare_dummy_input(self, height, width):
+ return FluxTransformerTests().prepare_dummy_input(height=height, width=width)
diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py
new file mode 100644
index 0000000000..884f168308
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py
@@ -0,0 +1,84 @@
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import SkyReelsV2Transformer3DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
+ model_class = SkyReelsV2Transformer3DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 1
+ num_channels = 4
+ num_frames = 2
+ height = 16
+ width = 16
+ text_encoder_embedding_dim = 16
+ sequence_length = 12
+
+ hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
+ timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 1, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 1, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "patch_size": (1, 2, 2),
+ "num_attention_heads": 2,
+ "attention_head_dim": 12,
+ "in_channels": 4,
+ "out_channels": 4,
+ "text_dim": 16,
+ "freq_dim": 256,
+ "ffn_dim": 32,
+ "num_layers": 2,
+ "cross_attn_norm": True,
+ "qk_norm": "rms_norm_across_heads",
+ "rope_max_seq_len": 32,
+ }
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"SkyReelsV2Transformer3DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index abf44aa744..123dff16f8 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -358,7 +358,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
- model_split_percents = [0.5, 0.3, 0.4]
+ model_split_percents = [0.5, 0.34, 0.4]
@property
def dummy_input(self):
diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py
deleted file mode 100644
index 94759d1f20..0000000000
--- a/tests/pipelines/amused/test_amused.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=8,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "height": 4,
- "width": 4,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.4011, 0.3992, 0.379, 0.3856, 0.3772, 0.3711, 0.3919, 0.385, 0.3625])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_256_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158])
- assert np.abs(image_slice - expected_slice).max() < 0.007
-
- def test_amused_512(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1199, 0.1171, 0.1229, 0.1188, 0.1210, 0.1147, 0.1260, 0.1346, 0.1152])
- assert np.abs(image_slice - expected_slice).max() < 0.003
-
- def test_amused_512_fp16(self):
- pipe = AmusedPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1509, 0.1492, 0.1531, 0.1485, 0.1501, 0.1465, 0.1581, 0.1690, 0.1499])
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py
deleted file mode 100644
index a76d82a2f0..0000000000
--- a/tests/pipelines/amused/test_amused_img2img.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedImg2ImgPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=8,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedImg2ImgPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.999, 0.9954, 1.0])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_256_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-256", torch_dtype=torch.float16, variant="fp16")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.998, 0.998, 0.994, 0.9944, 0.996, 0.9908, 1.0, 1.0, 0.9986])
- assert np.abs(image_slice - expected_slice).max() < 0.01
-
- def test_amused_512(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2809, 0.1879, 0.2027, 0.2418, 0.1852, 0.2145, 0.2484, 0.2425, 0.2317])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512_fp16(self):
- pipe = AmusedImg2ImgPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- image = pipe(
- "winter mountains",
- image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.2795, 0.1867, 0.2028, 0.2450, 0.1856, 0.2140, 0.2473, 0.2406, 0.2313])
- assert np.abs(image_slice - expected_slice).max() < 0.1
diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py
deleted file mode 100644
index 62f39de8c3..0000000000
--- a/tests/pipelines/amused/test_amused_inpaint.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AmusedInpaintPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- transformer = UVit2DModel(
- hidden_size=8,
- use_bias=False,
- hidden_dropout=0.0,
- cond_embed_dim=8,
- micro_cond_encode_dim=2,
- micro_cond_embed_dim=10,
- encoder_hidden_size=8,
- vocab_size=32,
- codebook_size=32,
- in_channels=8,
- block_out_channels=8,
- num_res_blocks=1,
- downsample=True,
- upsample=True,
- block_num_heads=1,
- num_hidden_layers=1,
- num_attention_heads=1,
- attention_dropout=0.0,
- intermediate_size=8,
- layer_norm_eps=1e-06,
- ln_elementwise_affine=True,
- )
- scheduler = AmusedScheduler(mask_token_id=31)
- torch.manual_seed(0)
- vqvae = VQModel(
- act_fn="silu",
- block_out_channels=[8],
- down_block_types=["DownEncoderBlock2D"],
- in_channels=3,
- latent_channels=8,
- layers_per_block=1,
- norm_num_groups=8,
- num_vq_embeddings=32,
- out_channels=3,
- sample_size=8,
- up_block_types=["UpDecoderBlock2D"],
- mid_block_add_attention=False,
- lookup_from_codebook=True,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=8,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- components = {
- "transformer": transformer,
- "scheduler": scheduler,
- "vqvae": vqvae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device)
- mask_image[0, 0, 0, 0] = 0
- mask_image[0, 0, 0, 1] = 0
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "output_type": "np",
- "image": image,
- "mask_image": mask_image,
- }
- return inputs
-
- def test_inference_batch_consistent(self, batch_sizes=[2]):
- self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False)
-
- @unittest.skip("aMUSEd does not support lists of generators")
- def test_inference_batch_single_identical(self): ...
-
-
-@slow
-@require_torch_accelerator
-class AmusedInpaintPipelineSlowTests(unittest.TestCase):
- def test_amused_256(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_256_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-256", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((256, 256))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((256, 256))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 256, 256, 3)
- expected_slice = np.array([0.0735, 0.0749, 0.065, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622])
- assert np.abs(image_slice - expected_slice).max() < 0.1
-
- def test_amused_512(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512")
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0])
- assert np.abs(image_slice - expected_slice).max() < 0.05
-
- def test_amused_512_fp16(self):
- pipe = AmusedInpaintPipeline.from_pretrained("amused/amused-512", variant="fp16", torch_dtype=torch.float16)
- pipe.to(torch_device)
- image = (
- load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg")
- .resize((512, 512))
- .convert("RGB")
- )
- mask_image = (
- load_image(
- "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png"
- )
- .resize((512, 512))
- .convert("L")
- )
- image = pipe(
- "winter mountains",
- image,
- mask_image,
- generator=torch.Generator().manual_seed(0),
- num_inference_steps=2,
- output_type="np",
- ).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0227, 0.0157, 0.0098, 0.0213, 0.0250, 0.0127, 0.0280, 0.0380, 0.0095])
- assert np.abs(image_slice - expected_slice).max() < 0.003
diff --git a/tests/pipelines/audioldm/test_audioldm.py b/tests/pipelines/audioldm/test_audioldm.py
deleted file mode 100644
index eb4139f0dc..0000000000
--- a/tests/pipelines/audioldm/test_audioldm.py
+++ /dev/null
@@ -1,461 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers import (
- ClapTextConfig,
- ClapTextModelWithProjection,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AudioLDMPipeline,
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = AudioLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- layers_per_block=1,
- norm_num_groups=8,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(8, 16),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=8,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[8, 16],
- in_channels=1,
- out_channels=1,
- norm_num_groups=8,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=1,
- num_hidden_layers=1,
- pad_token_id=1,
- vocab_size=1000,
- projection_dim=8,
- )
- text_encoder = ClapTextModelWithProjection(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
-
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_audioldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0050, 0.0050, -0.0060, 0.0033, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0033]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = audioldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- prompt_embeds = prompt_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- prompt_embeds = F.normalize(prompt_embeds, dim=-1)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = audioldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=audioldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = audioldm_pipe.text_encoder(
- text_inputs,
- )
- text_embeds = text_embeds.text_embeds
- # additional L_2 normalization over each hidden-state
- text_embeds = F.normalize(text_embeds, dim=-1)
-
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = audioldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_audioldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = audioldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0051, 0.0050, -0.0060, 0.0034, -0.0026, 0.0033, -0.0027, 0.0033, -0.0028, 0.0032]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-2
-
- def test_audioldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = audioldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = audioldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = audioldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = audioldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_audioldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = audioldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = audioldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = audioldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_audioldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- audioldm_pipe = AudioLDMPipeline(**components)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = audioldm_pipe.vocoder.config
- config.model_in_dim *= 2
- audioldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = audioldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
-
-@nightly
-class AudioLDMPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[77230:77240]
- expected_slice = np.array(
- [-0.4884, -0.4607, 0.0023, 0.5007, 0.5896, 0.5151, 0.3813, -0.0208, -0.3687, -0.4315]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-2
-
-
-@nightly
-class AudioLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_audioldm_lms(self):
- audioldm_pipe = AudioLDMPipeline.from_pretrained("cvssp/audioldm")
- audioldm_pipe.scheduler = LMSDiscreteScheduler.from_config(audioldm_pipe.scheduler.config)
- audioldm_pipe = audioldm_pipe.to(torch_device)
- audioldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = audioldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81920
-
- audio_slice = audio[27780:27790]
- expected_slice = np.array([-0.2131, -0.0873, -0.0124, -0.0189, 0.0569, 0.1373, 0.1883, 0.2886, 0.3297, 0.2212])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 3e-2
diff --git a/tests/pipelines/blipdiffusion/test_blipdiffusion.py b/tests/pipelines/blipdiffusion/test_blipdiffusion.py
deleted file mode 100644
index 0e3f723fc6..0000000000
--- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py
+++ /dev/null
@@ -1,204 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.utils.testing_utils import enable_full_determinism
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=8,
- intermediate_size=8,
- projection_dim=8,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(8,),
- norm_num_groups=8,
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- sample_size=8,
- )
-
- blip_vision_config = {
- "hidden_size": 8,
- "intermediate_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 8,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 8,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 8,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=8,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(8, 16),
- norm_num_groups=8,
- layers_per_block=1,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_blipdiffusion(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
-
- expected_slice = np.array(
- [0.5329548, 0.8372512, 0.33269387, 0.82096875, 0.43657133, 0.3783, 0.5953028, 0.51934963, 0.42142007]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
- )
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/chroma/test_pipeline_chroma.py b/tests/pipelines/chroma/test_pipeline_chroma.py
index fc5749f96c..5121a2b52d 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma.py
@@ -7,12 +7,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import torch_device
-from ..test_pipelines_common import (
- FluxIPAdapterTesterMixin,
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaPipelineFastTests(
@@ -126,12 +121,10 @@ class ChromaPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
index 02b20527b2..d518e1b7b8 100644
--- a/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
+++ b/tests/pipelines/chroma/test_pipeline_chroma_img2img.py
@@ -8,12 +8,7 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler
from diffusers.utils.testing_utils import floats_tensor, torch_device
-from ..test_pipelines_common import (
- FluxIPAdapterTesterMixin,
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist
class ChromaImg2ImgPipelineFastTests(
@@ -129,12 +124,10 @@ class ChromaImg2ImgPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/cogvideo/test_cogvideox.py b/tests/pipelines/cogvideo/test_cogvideox.py
index c725589781..a6cb558513 100644
--- a/tests/pipelines/cogvideo/test_cogvideox.py
+++ b/tests/pipelines/cogvideo/test_cogvideox.py
@@ -33,6 +33,7 @@ from diffusers.utils.testing_utils import (
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
@@ -45,7 +46,11 @@ enable_full_determinism()
class CogVideoXPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = CogVideoXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
deleted file mode 100644
index 100082b6f0..0000000000
--- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py
+++ /dev/null
@@ -1,228 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTokenizer
-from transformers.models.blip_2.configuration_blip_2 import Blip2Config
-from transformers.models.clip.configuration_clip import CLIPTextConfig
-
-from diffusers import (
- AutoencoderKL,
- BlipDiffusionControlNetPipeline,
- ControlNetModel,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
-from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
-from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = BlipDiffusionControlNetPipeline
- params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- batch_params = [
- "prompt",
- "reference_image",
- "source_subject_category",
- "target_subject_category",
- "condtioning_image",
- ]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "guidance_scale",
- "num_inference_steps",
- "neg_prompt",
- "guidance_scale",
- "prompt_strength",
- "prompt_reps",
- ]
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- vocab_size=1000,
- hidden_size=16,
- intermediate_size=16,
- projection_dim=16,
- num_hidden_layers=1,
- num_attention_heads=1,
- max_position_embeddings=77,
- )
- text_encoder = ContextCLIPTextModel(text_encoder_config)
-
- vae = AutoencoderKL(
- in_channels=4,
- out_channels=4,
- down_block_types=("DownEncoderBlock2D",),
- up_block_types=("UpDecoderBlock2D",),
- block_out_channels=(32,),
- layers_per_block=1,
- act_fn="silu",
- latent_channels=4,
- norm_num_groups=16,
- sample_size=16,
- )
-
- blip_vision_config = {
- "hidden_size": 16,
- "intermediate_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "image_size": 224,
- "patch_size": 14,
- "hidden_act": "quick_gelu",
- }
-
- blip_qformer_config = {
- "vocab_size": 1000,
- "hidden_size": 16,
- "num_hidden_layers": 1,
- "num_attention_heads": 1,
- "intermediate_size": 16,
- "max_position_embeddings": 512,
- "cross_attention_frequency": 1,
- "encoder_hidden_size": 16,
- }
- qformer_config = Blip2Config(
- vision_config=blip_vision_config,
- qformer_config=blip_qformer_config,
- num_query_tokens=16,
- tokenizer="hf-internal-testing/tiny-random-bert",
- )
- qformer = Blip2QFormerModel(qformer_config)
-
- unet = UNet2DConditionModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- norm_num_groups=4,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=16,
- )
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- scheduler = PNDMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- set_alpha_to_one=False,
- skip_prk_steps=True,
- )
- controlnet = ControlNetModel(
- block_out_channels=(4, 16),
- layers_per_block=1,
- in_channels=4,
- norm_num_groups=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- cross_attention_dim=16,
- conditioning_embedding_out_channels=(8, 16),
- )
-
- vae.eval()
- qformer.eval()
- text_encoder.eval()
-
- image_processor = BlipImageProcessor()
-
- components = {
- "text_encoder": text_encoder,
- "vae": vae,
- "qformer": qformer,
- "unet": unet,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "controlnet": controlnet,
- "image_processor": image_processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- np.random.seed(seed)
- reference_image = np.random.rand(32, 32, 3) * 255
- reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
- cond_image = np.random.rand(32, 32, 3) * 255
- cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "swimming underwater",
- "generator": generator,
- "reference_image": reference_image,
- "condtioning_image": cond_image,
- "source_subject_category": "dog",
- "target_subject_category": "dog",
- "height": 32,
- "width": 32,
- "guidance_scale": 7.5,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.4803, 0.3865, 0.1422, 0.6119, 0.2283, 0.6365, 0.5453, 0.5205, 0.3581])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_blipdiffusion_controlnet(self):
- device = "cpu"
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- image = pipe(**self.get_dummy_inputs(device))[0]
- image_slice = image[0, -3:, -3:, 0]
-
- assert image.shape == (1, 16, 16, 4)
- expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
-
- @unittest.skip("Test not supported because of complexities in deriving query_embeds.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
index 5ee94b09ba..5b336edc7a 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py
@@ -17,7 +17,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
@@ -211,7 +210,6 @@ class FluxControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin, Fl
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxControlNetPipeline
diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
index 8d63619c40..ab4cf32734 100644
--- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
+++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py
@@ -16,11 +16,7 @@ from diffusers.utils.testing_utils import (
)
from diffusers.utils.torch_utils import randn_tensor
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -170,12 +166,10 @@ class FluxControlNetImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMi
original_image_slice = image[0, -3:, -3:, -1]
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
index 712c26b0a2..1f1f800bcf 100644
--- a/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
+++ b/tests/pipelines/controlnet_sd3/test_controlnet_sd3.py
@@ -18,7 +18,6 @@ import unittest
from typing import Optional
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -221,7 +220,6 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTes
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3ControlNetPipeline
diff --git a/tests/pipelines/controlnet_xs/__init__.py b/tests/pipelines/controlnet_xs/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs.py b/tests/pipelines/controlnet_xs/test_controlnetxs.py
deleted file mode 100644
index 6f8422797c..0000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- DDIMScheduler,
- LCMScheduler,
- StableDiffusionControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_image,
- require_accelerator,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class ControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- SDFunctionTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self, time_cond_proj_dim=None):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- norm_num_groups=4,
- time_cond_proj_dim=time_cond_proj_dim,
- use_linear_projection=True,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=1,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "numpy",
- "image": image,
- }
-
- return inputs
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- def test_controlnet_lcm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components(time_cond_proj_dim=8)
- sd_pipe = StableDiffusionControlNetXSPipeline(**components)
- sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 16, 16, 3)
- expected_slice = np.array([0.745, 0.753, 0.767, 0.543, 0.523, 0.502, 0.314, 0.521, 0.478])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
- # So we need to move the new pipe to device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@require_torch_accelerator
-class ControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (768, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.1963, 0.229, 0.2659, 0.2109, 0.2332, 0.2827, 0.2534, 0.2422, 0.2808])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SD2.1-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- output = pipe(prompt, image, generator=generator, output_type="np", num_inference_steps=3)
-
- image = output.images[0]
-
- assert image.shape == (512, 512, 3)
-
- original_image = image[-3:, -3:, -1].flatten()
- expected_image = np.array([0.4844, 0.4937, 0.4956, 0.4663, 0.5039, 0.5044, 0.4565, 0.4883, 0.4941])
- assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py b/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
deleted file mode 100644
index 24a8b9cd57..0000000000
--- a/tests/pipelines/controlnet_xs/test_controlnetxs_sdxl.py
+++ /dev/null
@@ -1,393 +0,0 @@
-# coding=utf-8
-# Copyright 2023 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import (
- AsymmetricAutoencoderKL,
- AutoencoderKL,
- AutoencoderTiny,
- ConsistencyDecoderVAE,
- ControlNetXSAdapter,
- EulerDiscreteScheduler,
- StableDiffusionXLControlNetXSPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_image,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ...models.autoencoders.vae import (
- get_asym_autoencoder_kl_config,
- get_autoencoder_kl_config,
- get_autoencoder_tiny_config,
- get_consistency_vae_config,
-)
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionXLControlNetXSPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionXLControlNetXSPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- test_attention_slicing = False
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=16,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- use_linear_projection=True,
- norm_num_groups=4,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=56, # 6 * 8 (addition_time_embed_dim) + 8 (cross_attention_dim)
- cross_attention_dim=8,
- )
- torch.manual_seed(0)
- controlnet = ControlNetXSAdapter.from_unet(
- unet=unet,
- size_ratio=0.5,
- learn_time_embedding=True,
- conditioning_embedding_out_channels=(2, 2),
- )
- torch.manual_seed(0)
- scheduler = EulerDiscreteScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- steps_offset=1,
- beta_schedule="scaled_linear",
- timestep_spacing="leading",
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=8,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "controlnet": controlnet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "feature_extractor": None,
- }
- return components
-
- # Copied from test_controlnet_sdxl.py
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- controlnet_embedder_scale_factor = 2
- image = randn_tensor(
- (1, 3, 8 * controlnet_embedder_scale_factor, 8 * controlnet_embedder_scale_factor),
- generator=generator,
- device=torch.device(device),
- )
-
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "image": image,
- }
-
- return inputs
-
- # Copied from test_controlnet_sdxl.py
- def test_attention_slicing_forward_pass(self):
- return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- # Copied from test_controlnet_sdxl.py
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3)
-
- # Copied from test_controlnet_sdxl.py
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=2e-3)
-
- @unittest.skip("We test this functionality elsewhere already.")
- def test_save_load_optional_components(self):
- pass
-
- @require_torch_accelerator
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- pipe.unet.set_default_attn_processor()
-
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- # Copied from test_controlnet_sdxl.py
- def test_stable_diffusion_xl_multi_prompts(self):
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
-
- # forward with single prompt
- inputs = self.get_dummy_inputs(torch_device)
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = inputs["prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt_2"] = "different prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # manually set a negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- output = sd_pipe(**inputs)
- image_slice_1 = output.images[0, -3:, -3:, -1]
-
- # forward with same negative_prompt duplicated
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = inputs["negative_prompt"]
- output = sd_pipe(**inputs)
- image_slice_2 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are equal
- assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4
-
- # forward with different negative_prompt
- inputs = self.get_dummy_inputs(torch_device)
- inputs["negative_prompt"] = "negative prompt"
- inputs["negative_prompt_2"] = "different negative prompt"
- output = sd_pipe(**inputs)
- image_slice_3 = output.images[0, -3:, -3:, -1]
-
- # ensure the results are not equal
- assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
-
- # Copied from test_controlnetxs.py
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new UNetControlNetXSModel under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_multi_vae(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- block_out_channels = pipe.vae.config.block_out_channels
- norm_num_groups = pipe.vae.config.norm_num_groups
-
- vae_classes = [AutoencoderKL, AsymmetricAutoencoderKL, ConsistencyDecoderVAE, AutoencoderTiny]
- configs = [
- get_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_asym_autoencoder_kl_config(block_out_channels, norm_num_groups),
- get_consistency_vae_config(block_out_channels, norm_num_groups),
- get_autoencoder_tiny_config(block_out_channels),
- ]
-
- out_np = pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- for vae_cls, config in zip(vae_classes, configs):
- vae = vae_cls(**config)
- vae = vae.to(torch_device)
- components["vae"] = vae
- vae_pipe = self.pipeline_class(**components)
-
- # pipeline creates a new UNetControlNetXSModel under the hood, which aren't on device.
- # So we need to move the new pipe to device.
- vae_pipe.to(torch_device)
- vae_pipe.set_progress_bar_config(disable=None)
-
- out_vae_np = vae_pipe(**self.get_dummy_inputs_by_type(torch_device, input_image_type="np"))[0]
-
- assert out_vae_np.shape == out_np.shape
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_canny(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-canny", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "bird"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (768, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.3202, 0.3151, 0.3328, 0.3172, 0.337, 0.3381, 0.3378, 0.3389, 0.3224])
- assert np.allclose(original_image, expected_image, atol=1e-04)
-
- def test_depth(self):
- controlnet = ControlNetXSAdapter.from_pretrained(
- "UmerHA/Testing-ConrolNetXS-SDXL-depth", torch_dtype=torch.float16
- )
- pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
- )
- pipe.enable_sequential_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- prompt = "Stormtrooper's lecture"
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/stormtrooper_depth.png"
- )
-
- images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
- assert images[0].shape == (512, 512, 3)
-
- original_image = images[0, -3:, -3:, -1].flatten()
- expected_image = np.array([0.5448, 0.5437, 0.5426, 0.5543, 0.553, 0.5475, 0.5595, 0.5602, 0.5529])
- assert np.allclose(original_image, expected_image, atol=1e-04)
diff --git a/tests/pipelines/cosmos/test_cosmos.py b/tests/pipelines/cosmos/test_cosmos.py
index 0c1024a9a9..4d3202f785 100644
--- a/tests/pipelines/cosmos/test_cosmos.py
+++ b/tests/pipelines/cosmos/test_cosmos.py
@@ -153,11 +153,15 @@ class CosmosTextToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase)
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
- expected_video = torch.randn(9, 3, 32, 32)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.0, 0.9686, 0.8549, 0.8078, 0.0, 0.8431, 1.0, 0.4863, 0.7098, 0.1098, 0.8157, 0.4235, 0.6353, 0.2549, 0.5137, 0.5333])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/cosmos/test_cosmos2_text2image.py b/tests/pipelines/cosmos/test_cosmos2_text2image.py
index 386bf161a0..cc2fcec641 100644
--- a/tests/pipelines/cosmos/test_cosmos2_text2image.py
+++ b/tests/pipelines/cosmos/test_cosmos2_text2image.py
@@ -140,11 +140,15 @@ class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
-
self.assertEqual(generated_image.shape, (3, 32, 32))
- expected_video = torch.randn(3, 32, 32)
- max_diff = np.abs(generated_image - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.4784, 0.4784, 0.4784, 0.4784, 0.4784, 0.4902, 0.4588, 0.5333])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py
index 421e3a1ad3..b23c8aed17 100644
--- a/tests/pipelines/cosmos/test_cosmos2_video2world.py
+++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py
@@ -147,11 +147,15 @@ class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCas
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
- expected_video = torch.randn(9, 3, 32, 32)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.451, 0.451, 0.4471, 0.451, 0.451, 0.451, 0.451, 0.451, 0.5098, 0.5137, 0.5176, 0.5098, 0.5255, 0.5412, 0.5098, 0.5059])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
def test_components_function(self):
init_components = self.get_dummy_components()
diff --git a/tests/pipelines/cosmos/test_cosmos_video2world.py b/tests/pipelines/cosmos/test_cosmos_video2world.py
index 2b893e9970..d0dba5575b 100644
--- a/tests/pipelines/cosmos/test_cosmos_video2world.py
+++ b/tests/pipelines/cosmos/test_cosmos_video2world.py
@@ -159,11 +159,15 @@ class CosmosVideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
- expected_video = torch.randn(9, 3, 32, 32)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.0, 0.8275, 0.7529, 0.7294, 0.0, 0.6, 1.0, 0.3804, 0.6667, 0.0863, 0.8784, 0.5922, 0.6627, 0.2784, 0.5725, 0.7765])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
def test_components_function(self):
init_components = self.get_dummy_components()
diff --git a/tests/pipelines/dance_diffusion/__init__.py b/tests/pipelines/dance_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py
deleted file mode 100644
index a2a1753214..0000000000
--- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py
+++ /dev/null
@@ -1,174 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = DanceDiffusionPipeline
- params = UNCONDITIONAL_AUDIO_GENERATION_PARAMS
- required_optional_params = PipelineTesterMixin.required_optional_params - {
- "callback",
- "latents",
- "callback_steps",
- "output_type",
- "num_images_per_prompt",
- }
- batch_params = UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS
- test_attention_slicing = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet1DModel(
- block_out_channels=(32, 32, 64),
- extra_in_channels=16,
- sample_size=512,
- sample_rate=16_000,
- in_channels=2,
- out_channels=2,
- flip_sin_to_cos=True,
- use_timestep_embedding=False,
- time_embedding_type="fourier",
- mid_block_type="UNetMidBlock1D",
- down_block_types=("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"),
- up_block_types=("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"),
- )
- scheduler = IPNDMScheduler()
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "batch_size": 1,
- "generator": generator,
- "num_inference_steps": 4,
- }
- return inputs
-
- def test_dance_diffusion(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = DanceDiffusionPipeline(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = pipe(**inputs)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, components["unet"].sample_size)
- expected_slice = np.array([-0.7265, 1.0000, -0.8388, 0.1175, 0.9498, -1.0000])
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local()
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent(expected_max_difference=3e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- return super().test_attention_slicing_forward_pass()
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_accelerator
-class PipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_dance_diffusion(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k")
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0192, -0.0231, -0.0318, -0.0059, 0.0002, -0.0020])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_dance_diffusion_fp16(self):
- device = torch_device
-
- pipe = DanceDiffusionPipeline.from_pretrained("harmonai/maestro-150k", torch_dtype=torch.float16)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipe(generator=generator, num_inference_steps=100, audio_length_in_s=4.096)
- audio = output.audios
-
- audio_slice = audio[0, -3:, -3:]
-
- assert audio.shape == (1, 2, pipe.unet.config.sample_size)
- expected_slice = np.array([-0.0367, -0.0488, -0.0771, -0.0525, -0.0444, -0.0341])
-
- assert np.abs(audio_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py
index cbdf617d71..cc8266e1a5 100644
--- a/tests/pipelines/flux/test_pipeline_flux.py
+++ b/tests/pipelines/flux/test_pipeline_flux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -25,20 +24,21 @@ from diffusers.utils.testing_utils import (
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
+ check_qkv_fused_layers_exist,
)
class FluxPipelineFastTests(
- unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -154,7 +154,7 @@ class FluxPipelineFastTests(
# Outputs should be different here
# For some reasons, they don't show large differences
- assert max_diff > 1e-6
+ self.assertGreater(max_diff, 1e-6, "Outputs should be different for different prompts.")
def test_fused_qkv_projections(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@@ -170,12 +170,10 @@ class FluxPipelineFastTests(
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
@@ -186,14 +184,17 @@ class FluxPipelineFastTests(
image = pipe(**inputs).images
image_slice_disabled = image[0, -3:, -3:, -1]
- assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
- "Fusion of QKV projections shouldn't affect the outputs."
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
+ ("Fusion of QKV projections shouldn't affect the outputs."),
)
- assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
- "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
+ self.assertTrue(
+ np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
+ ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
)
- assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
- "Original outputs should match when fused QKV projections are disabled."
+ self.assertTrue(
+ np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
+ ("Original outputs should match when fused QKV projections are disabled."),
)
def test_flux_image_output_shape(self):
@@ -208,7 +209,11 @@ class FluxPipelineFastTests(
inputs.update({"height": height, "width": width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
- assert (output_height, output_width) == (expected_height, expected_width)
+ self.assertEqual(
+ (output_height, output_width),
+ (expected_height, expected_width),
+ f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
+ )
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
@@ -219,12 +224,13 @@ class FluxPipelineFastTests(
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
- assert not np.allclose(no_true_cfg_out, true_cfg_out)
+ self.assertFalse(
+ np.allclose(no_true_cfg_out, true_cfg_out), "Outputs should be different when true_cfg_scale is set."
+ )
@nightly
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -269,50 +275,21 @@ class FluxPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+ # fmt: off
expected_slice = np.array(
- [
- 0.3242,
- 0.3203,
- 0.3164,
- 0.3164,
- 0.3125,
- 0.3125,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3301,
- 0.3262,
- 0.3242,
- 0.3281,
- 0.3242,
- 0.3203,
- 0.3262,
- 0.3262,
- 0.3164,
- 0.3262,
- 0.3281,
- 0.3184,
- 0.3281,
- 0.3281,
- 0.3203,
- 0.3281,
- 0.3281,
- 0.3164,
- 0.3320,
- 0.3320,
- 0.3203,
- ],
+ [0.3242, 0.3203, 0.3164, 0.3164, 0.3125, 0.3125, 0.3281, 0.3242, 0.3203, 0.3301, 0.3262, 0.3242, 0.3281, 0.3242, 0.3203, 0.3262, 0.3262, 0.3164, 0.3262, 0.3281, 0.3184, 0.3281, 0.3281, 0.3203, 0.3281, 0.3281, 0.3164, 0.3320, 0.3320, 0.3203],
dtype=np.float32,
)
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4
+ self.assertLess(
+ max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+ )
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
@@ -378,42 +355,14 @@ class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
+ # fmt: off
expected_slice = np.array(
- [
- 0.1855,
- 0.1680,
- 0.1406,
- 0.1953,
- 0.1699,
- 0.1465,
- 0.2012,
- 0.1738,
- 0.1484,
- 0.2051,
- 0.1797,
- 0.1523,
- 0.2012,
- 0.1719,
- 0.1445,
- 0.2070,
- 0.1777,
- 0.1465,
- 0.2090,
- 0.1836,
- 0.1484,
- 0.2129,
- 0.1875,
- 0.1523,
- 0.2090,
- 0.1816,
- 0.1484,
- 0.2110,
- 0.1836,
- 0.1543,
- ],
+ [0.1855, 0.1680, 0.1406, 0.1953, 0.1699, 0.1465, 0.2012, 0.1738, 0.1484, 0.2051, 0.1797, 0.1523, 0.2012, 0.1719, 0.1445, 0.2070, 0.1777, 0.1465, 0.2090, 0.1836, 0.1484, 0.2129, 0.1875, 0.1523, 0.2090, 0.1816, 0.1484, 0.2110, 0.1836, 0.1543],
dtype=np.float32,
)
+ # fmt: on
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
-
- assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
+ self.assertLess(
+ max_diff, 1e-4, f"Image slice is different from expected slice: {image_slice} != {expected_slice}"
+ )
diff --git a/tests/pipelines/flux/test_pipeline_flux_control.py b/tests/pipelines/flux/test_pipeline_flux_control.py
index d8d0774e1e..42283da6fd 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control.py
@@ -8,11 +8,7 @@ from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPToken
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxControlPipeline, FluxTransformer2DModel
from diffusers.utils.testing_utils import torch_device
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -140,12 +136,10 @@ class FluxControlPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
index a2f7c91710..0abd08e373 100644
--- a/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
+++ b/tests/pipelines/flux/test_pipeline_flux_control_inpaint.py
@@ -15,11 +15,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
-from ..test_pipelines_common import (
- PipelineTesterMixin,
- check_qkv_fusion_matches_attn_procs_length,
- check_qkv_fusion_processors_exist,
-)
+from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@@ -134,12 +130,10 @@ class FluxControlInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin
# TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
# to the pipeline level.
pipe.transformer.fuse_qkv_projections()
- assert check_qkv_fusion_processors_exist(pipe.transformer), (
- "Something wrong with the fused attention processors. Expected all the attention processors to be fused."
+ self.assertTrue(
+ check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
+ ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
)
- assert check_qkv_fusion_matches_attn_procs_length(
- pipe.transformer, pipe.transformer.original_attn_processors
- ), "Something wrong with the attention processors concerning the fused QKV projections."
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext.py b/tests/pipelines/flux/test_pipeline_flux_kontext.py
new file mode 100644
index 0000000000..7471d78ad5
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext.py
@@ -0,0 +1,177 @@
+import unittest
+
+import numpy as np
+import PIL.Image
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ image = PIL.Image.new("RGB", (32, 32), 0)
+ inputs = {
+ "image": image,
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 8,
+ "width": 8,
+ "max_area": 8 * 8,
+ "max_sequence_length": 48,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 57)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+
+ inputs.update({"height": height, "width": width, "max_area": height * width})
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
diff --git a/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
new file mode 100644
index 0000000000..615209264d
--- /dev/null
+++ b/tests/pipelines/flux/test_pipeline_flux_kontext_inpaint.py
@@ -0,0 +1,190 @@
+import random
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKL,
+ FasterCacheConfig,
+ FlowMatchEulerDiscreteScheduler,
+ FluxKontextInpaintPipeline,
+ FluxTransformer2DModel,
+)
+from diffusers.utils.testing_utils import floats_tensor, torch_device
+
+from ..test_pipelines_common import (
+ FasterCacheTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+)
+
+
+class FluxKontextInpaintPipelineFastTests(
+ unittest.TestCase,
+ PipelineTesterMixin,
+ FluxIPAdapterTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+):
+ pipeline_class = FluxKontextInpaintPipeline
+ params = frozenset(
+ ["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
+ )
+ batch_params = frozenset(["image", "prompt"])
+
+ # there is no xformers processor for Flux
+ test_xformers_attention = False
+ test_layerwise_casting = True
+ test_group_offloading = True
+
+ faster_cache_config = FasterCacheConfig(
+ spatial_attention_block_skip_range=2,
+ spatial_attention_timestep_skip_range=(-1, 901),
+ unconditional_batch_skip_range=2,
+ attention_weight_callback=lambda _: 0.5,
+ is_guidance_distilled=True,
+ )
+
+ def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
+ torch.manual_seed(0)
+ transformer = FluxTransformer2DModel(
+ patch_size=1,
+ in_channels=4,
+ num_layers=num_layers,
+ num_single_layers=num_single_layers,
+ attention_head_dim=16,
+ num_attention_heads=2,
+ joint_attention_dim=32,
+ pooled_projection_dim=32,
+ axes_dims_rope=[4, 4, 8],
+ )
+ clip_text_encoder_config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=32,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ hidden_act="gelu",
+ projection_dim=32,
+ )
+
+ torch.manual_seed(0)
+ text_encoder = CLIPTextModel(clip_text_encoder_config)
+
+ torch.manual_seed(0)
+ text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=1,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+
+ return {
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "text_encoder_2": text_encoder_2,
+ "tokenizer": tokenizer,
+ "tokenizer_2": tokenizer_2,
+ "transformer": transformer,
+ "vae": vae,
+ "image_encoder": None,
+ "feature_extractor": None,
+ }
+
+ def get_dummy_inputs(self, device, seed=0):
+ image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+ mask_image = torch.ones((1, 1, 32, 32)).to(device)
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "image": image,
+ "mask_image": mask_image,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "max_sequence_length": 48,
+ "strength": 0.8,
+ "output_type": "np",
+ "_auto_resize": False,
+ }
+ return inputs
+
+ def test_flux_inpaint_different_prompts(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+
+ inputs = self.get_dummy_inputs(torch_device)
+ output_same_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs["prompt_2"] = "a different prompt"
+ output_different_prompts = pipe(**inputs).images[0]
+
+ max_diff = np.abs(output_same_prompt - output_different_prompts).max()
+
+ # Outputs should be different here
+ # For some reasons, they don't show large differences
+ assert max_diff > 1e-6
+
+ def test_flux_image_output_shape(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ height_width_pairs = [(32, 32), (72, 56)]
+ for height, width in height_width_pairs:
+ expected_height = height - height % (pipe.vae_scale_factor * 2)
+ expected_width = width - width % (pipe.vae_scale_factor * 2)
+ # Because output shape is the same as the input shape, we need to create a dummy image and mask image
+ image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device)
+ mask_image = torch.ones((1, 1, height, width)).to(torch_device)
+
+ inputs.update(
+ {
+ "height": height,
+ "width": width,
+ "max_area": height * width,
+ "image": image,
+ "mask_image": mask_image,
+ }
+ )
+ image = pipe(**inputs).images[0]
+ output_height, output_width, _ = image.shape
+ assert (output_height, output_width) == (expected_height, expected_width)
+
+ def test_flux_true_cfg(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+ inputs.pop("generator")
+
+ no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ inputs["negative_prompt"] = "bad quality"
+ inputs["true_cfg_scale"] = 2.0
+ true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
+ assert not np.allclose(no_true_cfg_out, true_cfg_out)
diff --git a/tests/pipelines/flux/test_pipeline_flux_redux.py b/tests/pipelines/flux/test_pipeline_flux_redux.py
index b8f36dfd3c..b73050a64d 100644
--- a/tests/pipelines/flux/test_pipeline_flux_redux.py
+++ b/tests/pipelines/flux/test_pipeline_flux_redux.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from diffusers import FluxPipeline, FluxPriorReduxPipeline
@@ -19,7 +18,6 @@ from diffusers.utils.testing_utils import (
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class FluxReduxSlowTests(unittest.TestCase):
pipeline_class = FluxPriorReduxPipeline
repo_id = "black-forest-labs/FLUX.1-Redux-dev"
diff --git a/tests/pipelines/hidream_image/test_pipeline_hidream.py b/tests/pipelines/hidream_image/test_pipeline_hidream.py
index ada4a11d16..1c5f30e870 100644
--- a/tests/pipelines/hidream_image/test_pipeline_hidream.py
+++ b/tests/pipelines/hidream_image/test_pipeline_hidream.py
@@ -146,11 +146,15 @@ class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
-
self.assertEqual(generated_image.shape, (128, 128, 3))
- expected_image = torch.randn(128, 128, 3).numpy()
- max_diff = np.abs(generated_image - expected_image).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = np.array([0.4507, 0.5256, 0.4205, 0.5791, 0.4848, 0.4831, 0.4443, 0.5107, 0.6586, 0.3163, 0.7318, 0.5933, 0.6252, 0.5512, 0.5357, 0.5983])
+ # fmt: on
+
+ generated_slice = generated_image.flatten()
+ generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(np.allclose(generated_slice, expected_slice, atol=1e-3))
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-4)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
index 6a4e3a8931..82281f28bc 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_image2video.py
@@ -229,12 +229,19 @@ class HunyuanVideoImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
# NOTE: The expected video has 4 lesser frames because they are dropped in the pipeline
self.assertEqual(generated_video.shape, (5, 3, 16, 16))
- expected_video = torch.randn(5, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.444, 0.479, 0.4485, 0.5752, 0.3539, 0.1548, 0.2706, 0.3593, 0.5323, 0.6635, 0.6795, 0.5255, 0.5091, 0.345, 0.4276, 0.4128])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
index 94d3c3739f..fad159c06b 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_skyreels_image2video.py
@@ -192,11 +192,18 @@ class HunyuanSkyreelsImageToVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.5832, 0.5498, 0.4839, 0.4744, 0.4515, 0.4832, 0.496, 0.563, 0.5918, 0.5979, 0.5101, 0.6168, 0.6613, 0.536, 0.55, 0.5775])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video.py b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
index ecc5eba964..26ec861522 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video.py
@@ -26,13 +26,11 @@ from diffusers import (
HunyuanVideoPipeline,
HunyuanVideoTransformer3DModel,
)
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- torch_device,
-)
+from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
to_np,
@@ -43,7 +41,11 @@ enable_full_determinism()
class HunyuanVideoPipelineFastTests(
- PipelineTesterMixin, PyramidAttentionBroadcastTesterMixin, FasterCacheTesterMixin, unittest.TestCase
+ PipelineTesterMixin,
+ PyramidAttentionBroadcastTesterMixin,
+ FasterCacheTesterMixin,
+ FirstBlockCacheTesterMixin,
+ unittest.TestCase,
):
pipeline_class = HunyuanVideoPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
@@ -201,11 +203,18 @@ class HunyuanVideoPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.3946, 0.4649, 0.3196, 0.4569, 0.3312, 0.3687, 0.3216, 0.3972, 0.4469, 0.3888, 0.3929, 0.3802, 0.3479, 0.3888, 0.3825, 0.3542])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
index 9f685d34c9..297c3df45a 100644
--- a/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
+++ b/tests/pipelines/hunyuan_video/test_hunyuan_video_framepack.py
@@ -227,11 +227,18 @@ class HunyuanVideoFramepackPipelineFastTests(
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (13, 3, 32, 32))
- expected_video = torch.randn(13, 3, 32, 32)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.363, 0.3384, 0.3426, 0.3512, 0.3372, 0.3276, 0.417, 0.4061, 0.5221, 0.467, 0.4813, 0.4556, 0.4107, 0.3945, 0.4049, 0.4551])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(
+ torch.allclose(generated_slice, expected_slice, atol=1e-3),
+ "The generated video does not match the expected slice.",
+ )
def test_callback_inputs(self):
sig = inspect.signature(self.pipeline_class.__call__)
diff --git a/tests/pipelines/i2vgen_xl/__init__.py b/tests/pipelines/i2vgen_xl/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py b/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
deleted file mode 100644
index bedd63738a..0000000000
--- a/tests/pipelines/i2vgen_xl/test_i2vgenxl.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import pytest
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- I2VGenXLPipeline,
-)
-from diffusers.models.unets import I2VGenXLUNet
-from diffusers.utils import is_xformers_available, load_image
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- is_torch_version,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase):
- pipeline_class = I2VGenXLPipeline
- params = frozenset(["prompt", "negative_prompt", "image"])
- batch_params = frozenset(["prompt", "negative_prompt", "image", "generator"])
- # No `output_type`.
- required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"])
-
- supports_dduf = False
- test_layerwise_casting = True
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- torch.manual_seed(0)
- unet = I2VGenXLUNet(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- num_attention_heads=None,
- norm_num_groups=2,
- )
-
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- torch.manual_seed(0)
- vision_encoder_config = CLIPVisionConfig(
- hidden_size=4,
- projection_dim=4,
- num_hidden_layers=2,
- num_attention_heads=2,
- image_size=32,
- intermediate_size=16,
- patch_size=1,
- )
- image_encoder = CLIPVisionModelWithProjection(vision_encoder_config)
-
- torch.manual_seed(0)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "image": input_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- "num_frames": 4,
- "width": 32,
- "height": 32,
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @pytest.mark.xfail(
- condition=is_torch_version(">=", "2.7"),
- reason="Test currently fails on PyTorch 2.7.",
- strict=False,
- )
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=0.006)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent(expected_max_difference=0.009)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.008)
-
- @unittest.skip("Deprecated functionality")
- def test_attention_slicing_forward_pass(self):
- pass
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=0.008)
-
- def test_model_cpu_offload_forward_pass(self):
- super().test_model_cpu_offload_forward_pass(expected_max_diff=0.008)
-
- def test_num_videos_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = pipe(**inputs, num_videos_per_prompt=2).frames
-
- assert frames.shape == (2, 4, 32, 32, 3)
- assert frames[0][0].shape == (32, 32, 3)
-
- image_slice = frames[0][0][-3:, -3:, -1]
- expected_slice = np.array([0.5146, 0.6525, 0.6032, 0.5204, 0.5675, 0.4125, 0.3016, 0.5172, 0.4095])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skip("Test not supported for now.")
- def test_encode_prompt_works_in_isolation(self):
- pass
-
-
-@slow
-@require_torch_accelerator
-class I2VGenXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_i2vgen_xl(self):
- pipe = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
- )
-
- generator = torch.Generator("cpu").manual_seed(0)
- num_frames = 3
-
- output = pipe(
- image=image,
- prompt="my cat",
- num_frames=num_frames,
- generator=generator,
- num_inference_steps=3,
- output_type="np",
- )
-
- image = output.frames[0]
- assert image.shape == (num_frames, 704, 1280, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5482, 0.6244, 0.6274, 0.4584, 0.5935, 0.5937, 0.4579, 0.5767, 0.5892])
- assert numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice.flatten()) < 1e-3
diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
index 84085f9d7d..b2d6f0fc05 100644
--- a/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
+++ b/tests/pipelines/kandinsky2_2/test_kandinsky_controlnet.py
@@ -289,6 +289,5 @@ class KandinskyV22ControlnetPipelineIntegrationTests(unittest.TestCase):
image = output.images[0]
assert image.shape == (512, 512, 3)
-
max_diff = numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten())
- assert max_diff < 1e-4
+ assert max_diff < 2e-4
diff --git a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
index 342561d4f5..ab0221dc81 100644
--- a/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
+++ b/tests/pipelines/ledits_pp/test_ledits_pp_stable_diffusion.py
@@ -29,6 +29,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import (
+ Expectations,
backend_empty_cache,
enable_full_determinism,
floats_tensor,
@@ -244,7 +245,35 @@ class LEditsPPPipelineStableDiffusionSlowTests(unittest.TestCase):
output_slice = reconstruction[150:153, 140:143, -1]
output_slice = output_slice.flatten()
- expected_slice = np.array(
- [0.9453125, 0.93310547, 0.84521484, 0.94628906, 0.9111328, 0.80859375, 0.93847656, 0.9042969, 0.8144531]
+ expected_slices = Expectations(
+ {
+ ("xpu", 3): np.array(
+ [
+ 0.9511719,
+ 0.94140625,
+ 0.87597656,
+ 0.9472656,
+ 0.9296875,
+ 0.8378906,
+ 0.94433594,
+ 0.91503906,
+ 0.8491211,
+ ]
+ ),
+ ("cuda", 7): np.array(
+ [
+ 0.9453125,
+ 0.93310547,
+ 0.84521484,
+ 0.94628906,
+ 0.9111328,
+ 0.80859375,
+ 0.93847656,
+ 0.9042969,
+ 0.8144531,
+ ]
+ ),
+ }
)
+ expected_slice = expected_slices.get_expectation()
assert np.abs(output_slice - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/ltx/test_ltx.py b/tests/pipelines/ltx/test_ltx.py
index 1d1eb08234..bf0c7fde59 100644
--- a/tests/pipelines/ltx/test_ltx.py
+++ b/tests/pipelines/ltx/test_ltx.py
@@ -23,13 +23,13 @@ from diffusers import AutoencoderKLLTXVideo, FlowMatchEulerDiscreteScheduler, LT
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, to_np
+from ..test_pipelines_common import FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class LTXPipelineFastTests(PipelineTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase):
pipeline_class = LTXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -49,7 +49,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
test_layerwise_casting = True
test_group_offloading = True
- def get_dummy_components(self):
+ def get_dummy_components(self, num_layers: int = 1):
torch.manual_seed(0)
transformer = LTXVideoTransformer3DModel(
in_channels=8,
@@ -59,7 +59,7 @@ class LTXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
num_attention_heads=4,
attention_head_dim=8,
cross_attention_dim=32,
- num_layers=1,
+ num_layers=num_layers,
caption_channels=32,
)
diff --git a/tests/pipelines/mochi/test_mochi.py b/tests/pipelines/mochi/test_mochi.py
index 5b00261b06..f1684cce72 100644
--- a/tests/pipelines/mochi/test_mochi.py
+++ b/tests/pipelines/mochi/test_mochi.py
@@ -17,7 +17,6 @@ import inspect
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -33,13 +32,15 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import FasterCacheTesterMixin, PipelineTesterMixin, to_np
+from ..test_pipelines_common import FasterCacheTesterMixin, FirstBlockCacheTesterMixin, PipelineTesterMixin, to_np
enable_full_determinism()
-class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unittest.TestCase):
+class MochiPipelineFastTests(
+ PipelineTesterMixin, FasterCacheTesterMixin, FirstBlockCacheTesterMixin, unittest.TestCase
+):
pipeline_class = MochiPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
@@ -268,7 +269,6 @@ class MochiPipelineFastTests(PipelineTesterMixin, FasterCacheTesterMixin, unitte
@nightly
@require_torch_accelerator
@require_big_accelerator
-@pytest.mark.big_accelerator
class MochiPipelineIntegrationTests(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger."
diff --git a/tests/pipelines/musicldm/__init__.py b/tests/pipelines/musicldm/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py
deleted file mode 100644
index 5d6392865b..0000000000
--- a/tests/pipelines/musicldm/test_musicldm.py
+++ /dev/null
@@ -1,478 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- ClapAudioConfig,
- ClapConfig,
- ClapFeatureExtractor,
- ClapModel,
- ClapTextConfig,
- RobertaTokenizer,
- SpeechT5HifiGan,
- SpeechT5HifiGanConfig,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- LMSDiscreteScheduler,
- MusicLDMPipeline,
- PNDMScheduler,
- UNet2DConditionModel,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = MusicLDMPipeline
- params = TEXT_TO_AUDIO_PARAMS
- batch_params = TEXT_TO_AUDIO_BATCH_PARAMS
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "num_waveforms_per_prompt",
- "generator",
- "latents",
- "output_type",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=(32, 64),
- class_embed_type="simple_projection",
- projection_class_embeddings_input_dim=32,
- class_embeddings_concat=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=1,
- out_channels=1,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_branch_config = ClapTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=16,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- )
- audio_branch_config = ClapAudioConfig(
- spec_size=64,
- window_size=4,
- num_mel_bins=64,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- depths=[2, 2],
- num_attention_heads=[2, 2],
- num_hidden_layers=2,
- hidden_size=192,
- patch_size=2,
- patch_stride=2,
- patch_embed_input_channels=4,
- )
- text_encoder_config = ClapConfig.from_text_audio_configs(
- text_config=text_branch_config, audio_config=audio_branch_config, projection_dim=32
- )
- text_encoder = ClapModel(text_encoder_config)
- tokenizer = RobertaTokenizer.from_pretrained("hf-internal-testing/tiny-random-roberta", model_max_length=77)
- feature_extractor = ClapFeatureExtractor.from_pretrained(
- "hf-internal-testing/tiny-random-ClapModel", hop_length=7900
- )
-
- torch.manual_seed(0)
- vocoder_config = SpeechT5HifiGanConfig(
- model_in_dim=8,
- sampling_rate=16000,
- upsample_initial_channel=16,
- upsample_rates=[2, 2],
- upsample_kernel_sizes=[4, 4],
- resblock_kernel_sizes=[3, 7],
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5]],
- normalize_before=False,
- )
-
- vocoder = SpeechT5HifiGan(vocoder_config)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": feature_extractor,
- "vocoder": vocoder,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- }
- return inputs
-
- def test_musicldm_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(**inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0020, -0.0035, -0.0019, -0.0037, -0.0020, -0.0038, -0.0019]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = musicldm_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = musicldm_pipe.text_encoder.get_text_features(text_inputs)
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt_embeds(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- negative_prompt = 3 * ["this is a negative prompt"]
- inputs["negative_prompt"] = negative_prompt
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_1 = output.audios[0]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- embeds = []
- for p in [prompt, negative_prompt]:
- text_inputs = musicldm_pipe.tokenizer(
- p,
- padding="max_length",
- max_length=musicldm_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- text_embeds = musicldm_pipe.text_encoder.get_text_features(
- text_inputs,
- )
- embeds.append(text_embeds)
-
- inputs["prompt_embeds"], inputs["negative_prompt_embeds"] = embeds
-
- # forward
- output = musicldm_pipe(**inputs)
- audio_2 = output.audios[0]
-
- assert np.abs(audio_1 - audio_2).max() < 1e-2
-
- def test_musicldm_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "egg cracking"
- output = musicldm_pipe(**inputs, negative_prompt=negative_prompt)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 256
-
- audio_slice = audio[:10]
- expected_slice = np.array(
- [-0.0027, -0.0036, -0.0037, -0.0019, -0.0035, -0.0018, -0.0037, -0.0021, -0.0038, -0.0018]
- )
-
- assert np.abs(audio_slice - expected_slice).max() < 1e-4
-
- def test_musicldm_num_waveforms_per_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A hammer hitting a wooden surface"
-
- # test num_waveforms_per_prompt=1 (default)
- audios = musicldm_pipe(prompt, num_inference_steps=2).audios
-
- assert audios.shape == (1, 256)
-
- # test num_waveforms_per_prompt=1 (default) for batch of prompts
- batch_size = 2
- audios = musicldm_pipe([prompt] * batch_size, num_inference_steps=2).audios
-
- assert audios.shape == (batch_size, 256)
-
- # test num_waveforms_per_prompt for single prompt
- num_waveforms_per_prompt = 2
- audios = musicldm_pipe(prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt).audios
-
- assert audios.shape == (num_waveforms_per_prompt, 256)
-
- # test num_waveforms_per_prompt for batch of prompts
- batch_size = 2
- audios = musicldm_pipe(
- [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt
- ).audios
-
- assert audios.shape == (batch_size * num_waveforms_per_prompt, 256)
-
- def test_musicldm_audio_length_in_s(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
- vocoder_sampling_rate = musicldm_pipe.vocoder.config.sampling_rate
-
- inputs = self.get_dummy_inputs(device)
- output = musicldm_pipe(audio_length_in_s=0.016, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.016
-
- output = musicldm_pipe(audio_length_in_s=0.032, **inputs)
- audio = output.audios[0]
-
- assert audio.ndim == 1
- assert len(audio) / vocoder_sampling_rate == 0.032
-
- def test_musicldm_vocoder_model_in_dim(self):
- components = self.get_dummy_components()
- musicldm_pipe = MusicLDMPipeline(**components)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- prompt = ["hey"]
-
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- assert audio_shape == (1, 256)
-
- config = musicldm_pipe.vocoder.config
- config.model_in_dim *= 2
- musicldm_pipe.vocoder = SpeechT5HifiGan(config).to(torch_device)
- output = musicldm_pipe(prompt, num_inference_steps=1)
- audio_shape = output.audios.shape
- # waveform shape is unchanged, we just have 2x the number of mel channels in the spectrogram
- assert audio_shape == (1, 256)
-
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False)
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # The method component.dtype returns the dtype of the first parameter registered in the model, not the
- # dtype of the entire model. In the case of CLAP, the first parameter is a float64 constant (logit scale)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
-
- # Without the logit scale parameters, everything is float32
- model_dtypes.pop("text_encoder")
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # the CLAP sub-models are float32
- model_dtypes["clap_text_branch"] = components["text_encoder"].text_model.dtype
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))
-
- # Once we send to fp16, all params are in half-precision, including the logit scale
- pipe.to(dtype=torch.float16)
- model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
-
-
-@nightly
-@require_torch_accelerator
-class MusicLDMPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 8, 128, 16))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "A hammer hitting a wooden surface",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 2.5,
- }
- return inputs
-
- def test_musicldm(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- inputs["num_inference_steps"] = 25
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[8680:8690]
- expected_slice = np.array(
- [-0.1042, -0.1068, -0.1235, -0.1387, -0.1428, -0.136, -0.1213, -0.1097, -0.0967, -0.0945]
- )
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
-
- def test_musicldm_lms(self):
- musicldm_pipe = MusicLDMPipeline.from_pretrained("cvssp/musicldm")
- musicldm_pipe.scheduler = LMSDiscreteScheduler.from_config(musicldm_pipe.scheduler.config)
- musicldm_pipe = musicldm_pipe.to(torch_device)
- musicldm_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- audio = musicldm_pipe(**inputs).audios[0]
-
- assert audio.ndim == 1
- assert len(audio) == 81952
-
- # check the portion of the generated audio with the largest dynamic range (reduces flakiness)
- audio_slice = audio[58020:58030]
- expected_slice = np.array([0.3592, 0.3477, 0.4084, 0.4665, 0.5048, 0.5891, 0.6461, 0.5579, 0.4595, 0.4403])
- max_diff = np.abs(expected_slice - audio_slice).max()
- assert max_diff < 1e-3
diff --git a/tests/pipelines/paint_by_example/__init__.py b/tests/pipelines/paint_by_example/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py
deleted file mode 100644
index f122c7411d..0000000000
--- a/tests/pipelines/paint_by_example/test_paint_by_example.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPImageProcessor, CLIPVisionConfig
-
-from diffusers import AutoencoderKL, PaintByExamplePipeline, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.paint_by_example import PaintByExampleImageEncoder
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = PaintByExamplePipeline
- params = IMAGE_GUIDED_IMAGE_INPAINTING_PARAMS
- batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
- image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=9,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = PNDMScheduler(skip_prk_steps=True)
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- image_size=32,
- patch_size=4,
- )
- image_encoder = PaintByExampleImageEncoder(config, proj_size=32)
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "image_encoder": image_encoder,
- "safety_checker": None,
- "feature_extractor": feature_extractor,
- }
- return components
-
- def convert_to_pt(self, image):
- image = np.array(image.convert("RGB"))
- image = image[None].transpose(0, 3, 1, 2)
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
- return image
-
- def get_dummy_inputs(self, device="cpu", seed=0):
- # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- init_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
- mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((64, 64))
- example_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((32, 32))
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "example_image": example_image,
- "image": init_image,
- "mask_image": mask_image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_paint_by_example_inpaint(self):
- components = self.get_dummy_components()
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**components)
- pipe = pipe.to("cpu")
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs()
- output = pipe(**inputs)
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.4686, 0.5687, 0.4007, 0.5218, 0.5741, 0.4482, 0.4940, 0.4629, 0.4503])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_paint_by_example_image_tensor(self):
- device = "cpu"
- inputs = self.get_dummy_inputs()
- inputs.pop("mask_image")
- image = self.convert_to_pt(inputs.pop("image"))
- mask_image = image.clamp(0, 1) / 2
-
- # make sure here that pndm scheduler skips prk
- pipe = PaintByExamplePipeline(**self.get_dummy_components())
- pipe = pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(image=image, mask_image=mask_image[:, 0], **inputs)
- out_1 = output.images
-
- image = image.cpu().permute(0, 2, 3, 1)[0]
- mask_image = mask_image.cpu().permute(0, 2, 3, 1)[0]
-
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- mask_image = Image.fromarray(np.uint8(mask_image)).convert("RGB")
-
- output = pipe(**self.get_dummy_inputs())
- out_2 = output.images
-
- assert out_1.shape == (1, 64, 64, 3)
- assert np.abs(out_1.flatten() - out_2.flatten()).max() < 5e-2
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
-
-@nightly
-@require_torch_accelerator
-class PaintByExamplePipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_paint_by_example(self):
- # make sure here that pndm scheduler skips prk
- init_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/dog_in_bucket.png"
- )
- mask_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/mask.png"
- )
- example_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/paint_by_example/panda.jpg"
- )
-
- pipe = PaintByExamplePipeline.from_pretrained("Fantasy-Studio/Paint-by-Example")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(321)
- output = pipe(
- image=init_image,
- mask_image=mask_image,
- example_image=example_image,
- generator=generator,
- guidance_scale=5.0,
- num_inference_steps=50,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.4834, 0.4811, 0.4874, 0.5122, 0.5081, 0.5144, 0.5291, 0.5290, 0.5374])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/pia/__init__.py b/tests/pipelines/pia/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
deleted file mode 100644
index 1156bf32da..0000000000
--- a/tests/pipelines/pia/test_pia.py
+++ /dev/null
@@ -1,448 +0,0 @@
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-import diffusers
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DPMSolverMultistepScheduler,
- LCMScheduler,
- MotionAdapter,
- PIAPipeline,
- StableDiffusionPipeline,
- UNet2DConditionModel,
- UNetMotionModel,
-)
-from diffusers.utils import is_xformers_available, logging
-from diffusers.utils.testing_utils import floats_tensor, require_accelerator, torch_device
-
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = PIAPipeline
- params = frozenset(
- [
- "prompt",
- "height",
- "width",
- "guidance_scale",
- "negative_prompt",
- "prompt_embeds",
- "negative_prompt_embeds",
- "cross_attention_kwargs",
- ]
- )
- batch_params = frozenset(["prompt", "image", "generator"])
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback_on_step_end",
- "callback_on_step_end_tensor_inputs",
- ]
- )
- test_layerwise_casting = True
- test_group_offloading = True
-
- def get_dummy_components(self):
- cross_attention_dim = 8
- block_out_channels = (8, 8)
-
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=block_out_channels,
- layers_per_block=2,
- sample_size=8,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=cross_attention_dim,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="linear",
- clip_sample=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=block_out_channels,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=cross_attention_dim,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- torch.manual_seed(0)
- motion_adapter = MotionAdapter(
- block_out_channels=block_out_channels,
- motion_layers_per_block=2,
- motion_norm_num_groups=2,
- motion_num_attention_heads=4,
- conv_in_channels=9,
- )
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "motion_adapter": motion_adapter,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- image = floats_tensor((1, 3, 8, 8), rng=random.Random(seed)).to(device)
- inputs = {
- "image": image,
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 7.5,
- "output_type": "pt",
- }
- return inputs
-
- def test_from_pipe_consistent_config(self):
- assert self.original_pipeline_class == StableDiffusionPipeline
- original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
- original_kwargs = {"requires_safety_checker": False}
-
- # create original_pipeline_class(sd)
- pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
-
- # original_pipeline_class(sd) -> pipeline_class
- pipe_components = self.get_dummy_components()
- pipe_additional_components = {}
- for name, component in pipe_components.items():
- if name not in pipe_original.components:
- pipe_additional_components[name] = component
-
- pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
-
- # pipeline_class -> original_pipeline_class(sd)
- original_pipe_additional_components = {}
- for name, component in pipe_original.components.items():
- if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
- original_pipe_additional_components[name] = component
-
- pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
-
- # compare the config
- original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
- original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
- assert original_config_2 == original_config
-
- def test_motion_unet_loading(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
-
- assert isinstance(pipe.unet, UNetMotionModel)
-
- def test_ip_adapter(self):
- expected_pipe_slice = None
-
- if torch_device == "cpu":
- expected_pipe_slice = np.array(
- [
- 0.5475,
- 0.5769,
- 0.4873,
- 0.5064,
- 0.4445,
- 0.5876,
- 0.5453,
- 0.4102,
- 0.5247,
- 0.5370,
- 0.3406,
- 0.4322,
- 0.3991,
- 0.3756,
- 0.5438,
- 0.4780,
- 0.5087,
- 0.5248,
- 0.6243,
- 0.5506,
- 0.3491,
- 0.5440,
- 0.6111,
- 0.5122,
- 0.5326,
- 0.5180,
- 0.5538,
- ]
- )
- return super().test_ip_adapter(expected_pipe_slice=expected_pipe_slice)
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5476, 0.4092, 0.5289, 0.4755, 0.5092, 0.5186, 0.5403, 0.5287, 0.5467])
- return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- @unittest.skip("Attention slicing is not enabled in this pipeline")
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_inference_batch_single_identical(
- self,
- batch_size=2,
- expected_max_diff=1e-4,
- additional_params_copy_to_batched_inputs=["num_inference_steps"],
- ):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for components in pipe.components.values():
- if hasattr(components, "set_default_attn_processor"):
- components.set_default_attn_processor()
-
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- inputs = self.get_dummy_inputs(torch_device)
- # Reset generator in case it is has been used in self.get_dummy_inputs
- inputs["generator"] = self.get_generator(0)
-
- logger = logging.get_logger(pipe.__module__)
- logger.setLevel(level=diffusers.logging.FATAL)
-
- # batchify inputs
- batched_inputs = {}
- batched_inputs.update(inputs)
-
- for name in self.batch_params:
- if name not in inputs:
- continue
-
- value = inputs[name]
- if name == "prompt":
- len_prompt = len(value)
- batched_inputs[name] = [value[: len_prompt // i] for i in range(1, batch_size + 1)]
- batched_inputs[name][-1] = 100 * "very long"
-
- else:
- batched_inputs[name] = batch_size * [value]
-
- if "generator" in inputs:
- batched_inputs["generator"] = [self.get_generator(i) for i in range(batch_size)]
-
- if "batch_size" in inputs:
- batched_inputs["batch_size"] = batch_size
-
- for arg in additional_params_copy_to_batched_inputs:
- batched_inputs[arg] = inputs[arg]
-
- output = pipe(**inputs)
- output_batch = pipe(**batched_inputs)
-
- assert output_batch[0].shape[0] == batch_size
-
- max_diff = np.abs(to_np(output_batch[0][0]) - to_np(output[0][0])).max()
- assert max_diff < expected_max_diff
-
- @require_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- # pipeline creates a new motion UNet under the hood. So we need to check the device from pipe.components
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [
- component.device.type for component in pipe.components.values() if hasattr(component, "device")
- ]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- def test_to_dtype(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- # pipeline creates a new motion UNet under the hood. So we need to check the dtype from pipe.components
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
-
- pipe.to(dtype=torch.float16)
- model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
- self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
-
- def test_prompt_embeds(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs.pop("prompt")
- inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
- pipe(**inputs)
-
- def test_free_init(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- pipe.enable_free_init(
- num_iters=2,
- use_fast_sampling=True,
- method="butterworth",
- order=4,
- spatial_stop_frequency=0.25,
- temporal_stop_frequency=0.25,
- )
- inputs_enable_free_init = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs_enable_free_init).frames[0]
-
- pipe.disable_free_init()
- inputs_disable_free_init = self.get_dummy_inputs(torch_device)
- frames_disable_free_init = pipe(**inputs_disable_free_init).frames[0]
-
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
- max_diff_disabled = np.abs(to_np(frames_normal) - to_np(frames_disable_free_init)).max()
- self.assertGreater(
- sum_enabled, 1e1, "Enabling of FreeInit should lead to results different from the default pipeline results"
- )
- self.assertLess(
- max_diff_disabled,
- 1e-4,
- "Disabling of FreeInit should lead to results similar to the default pipeline results",
- )
-
- def test_free_init_with_schedulers(self):
- components = self.get_dummy_components()
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- inputs_normal = self.get_dummy_inputs(torch_device)
- frames_normal = pipe(**inputs_normal).frames[0]
-
- schedulers_to_test = [
- DPMSolverMultistepScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- algorithm_type="dpmsolver++",
- steps_offset=1,
- clip_sample=False,
- ),
- LCMScheduler.from_config(
- components["scheduler"].config,
- timestep_spacing="linspace",
- beta_schedule="linear",
- steps_offset=1,
- clip_sample=False,
- ),
- ]
- components.pop("scheduler")
-
- for scheduler in schedulers_to_test:
- components["scheduler"] = scheduler
- pipe: PIAPipeline = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
- pipe.to(torch_device)
-
- pipe.enable_free_init(num_iters=2, use_fast_sampling=False)
-
- inputs = self.get_dummy_inputs(torch_device)
- frames_enable_free_init = pipe(**inputs).frames[0]
- sum_enabled = np.abs(to_np(frames_normal) - to_np(frames_enable_free_init)).sum()
-
- self.assertGreater(
- sum_enabled,
- 1e1,
- "Enabling of FreeInit should lead to results different from the default pipeline results",
- )
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- for component in pipe.components.values():
- if hasattr(component, "set_default_attn_processor"):
- component.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- output_without_offload = pipe(**inputs).frames[0]
- output_without_offload = (
- output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
- )
-
- pipe.enable_xformers_memory_efficient_attention()
- inputs = self.get_dummy_inputs(torch_device)
- output_with_offload = pipe(**inputs).frames[0]
- output_with_offload = (
- output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
- )
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, 1e-4, "XFormers attention should not affect the inference results")
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
diff --git a/tests/pipelines/semantic_stable_diffusion/__init__.py b/tests/pipelines/semantic_stable_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py b/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
deleted file mode 100644
index b4d82b0fb2..0000000000
--- a/tests/pipelines/semantic_stable_diffusion/test_semantic_diffusion.py
+++ /dev/null
@@ -1,617 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipeline as StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_semantic_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5753, 0.6114, 0.5001, 0.5034, 0.5470, 0.4729, 0.4971, 0.4867, 0.4867])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5122, 0.5712, 0.4825, 0.5053, 0.5646, 0.4769, 0.5179, 0.4894, 0.4994])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_semantic_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_torch_accelerator
- def test_semantic_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_accelerator
-class SemanticDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_positive_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34673113,
- 0.38492733,
- 0.37597352,
- 0.34086335,
- 0.35650748,
- 0.35579205,
- 0.3384763,
- 0.34340236,
- 0.3573271,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.41887826,
- 0.37728766,
- 0.30138272,
- 0.41416335,
- 0.41664985,
- 0.36283392,
- 0.36191246,
- 0.43364465,
- 0.43001732,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_negative_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "an image of a crowded boulevard, realistic, 4k"
- edit = {
- "editing_prompt": "crowd, crowded, people",
- "reverse_editing_direction": True,
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 8.3,
- "edit_threshold": 0.9,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 9
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.43497998,
- 0.91814065,
- 0.7540739,
- 0.55580205,
- 0.8467265,
- 0.5389691,
- 0.62574506,
- 0.58897763,
- 0.50926757,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.3089719,
- 0.30500144,
- 0.29016042,
- 0.30630964,
- 0.325687,
- 0.29419225,
- 0.2908091,
- 0.28723598,
- 0.27696294,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_multi_cond_guidance(self):
- pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a castle next to a river"
- edit = {
- "editing_prompt": ["boat on a river, boat", "monet, impression, sunrise"],
- "reverse_editing_direction": False,
- "edit_warmup_steps": [15, 18],
- "edit_guidance_scale": 6,
- "edit_threshold": [0.9, 0.8],
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 48
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.75163555,
- 0.76037145,
- 0.61785,
- 0.9189673,
- 0.8627701,
- 0.85189694,
- 0.8512813,
- 0.87012076,
- 0.8312857,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.73553365,
- 0.7537271,
- 0.74341905,
- 0.66480356,
- 0.6472925,
- 0.63039416,
- 0.64812905,
- 0.6749717,
- 0.6517102,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_guidance_fp16(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
- )
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "a photo of a cat"
- edit = {
- "editing_prompt": ["sunglasses"],
- "reverse_editing_direction": [False],
- "edit_warmup_steps": 10,
- "edit_guidance_scale": 6,
- "edit_threshold": 0.95,
- "edit_momentum_scale": 0.5,
- "edit_mom_beta": 0.6,
- }
-
- seed = 3
- guidance_scale = 7
-
- # no sega enabled
- generator = torch.Generator(torch_device)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.34887695,
- 0.3876953,
- 0.375,
- 0.34423828,
- 0.3581543,
- 0.35717773,
- 0.3383789,
- 0.34570312,
- 0.359375,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # with sega enabled
- # generator = torch.manual_seed(seed)
- generator.manual_seed(seed)
- output = pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- **edit,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [
- 0.42285156,
- 0.36914062,
- 0.29077148,
- 0.42041016,
- 0.41918945,
- 0.35498047,
- 0.3618164,
- 0.4423828,
- 0.43115234,
- ]
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/audioldm/__init__.py b/tests/pipelines/skyreels_v2/__init__.py
similarity index 100%
rename from tests/pipelines/audioldm/__init__.py
rename to tests/pipelines/skyreels_v2/__init__.py
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2.py b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
new file mode 100644
index 0000000000..adbbf05325
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2Pipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2Pipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
new file mode 100644
index 0000000000..cf9070bb95
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=8.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
new file mode 100644
index 0000000000..7b8a299281
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_image_to_video.py
@@ -0,0 +1,215 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingImageToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
+
+
+class SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests(SkyReelsV2DiffusionForcingImageToVideoPipelineFastTests):
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ last_image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "last_image": last_image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative",
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
new file mode 100644
index 0000000000..bc6a9acbf7
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_df_video_to_video.py
@@ -0,0 +1,201 @@
+# Copyright 2025 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import AutoTokenizer, T5EncoderModel
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2DiffusionForcingVideoToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..pipeline_params import TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import (
+ PipelineTesterMixin,
+)
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2DiffusionForcingVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2DiffusionForcingVideoToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+ batch_params = frozenset(["video", "prompt", "negative_prompt"])
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+
+ video = [Image.new("RGB", (16, 16))] * 7
+ inputs = {
+ "video": video,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "generator": generator,
+ "num_inference_steps": 4,
+ "guidance_scale": 6.0,
+ "height": 16,
+ "width": 16,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ "overlap_history": 3,
+ "num_frames": 17,
+ "base_num_frames": 5,
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ total_frames = len(inputs["video"]) + inputs["num_frames"]
+ expected_shape = (total_frames, 3, 16, 16)
+ self.assertEqual(generated_video.shape, expected_shape)
+ expected_video = torch.randn(*expected_shape)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_callback_cfg(self):
+ sig = inspect.signature(self.pipeline_class.__call__)
+ has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
+ has_callback_step_end = "callback_on_step_end" in sig.parameters
+
+ if not (has_callback_tensor_inputs and has_callback_step_end):
+ return
+
+ if "guidance_scale" not in sig.parameters:
+ return
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(torch_device)
+ pipe.set_progress_bar_config(disable=None)
+ self.assertTrue(
+ hasattr(pipe, "_callback_tensor_inputs"),
+ f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
+ )
+
+ # Track the number of callback calls for diffusion forcing pipelines
+ callback_call_count = [0] # Use list to make it mutable in closure
+
+ def callback_increase_guidance(pipe, i, t, callback_kwargs):
+ pipe._guidance_scale += 1.0
+ callback_call_count[0] += 1
+ return callback_kwargs
+
+ inputs = self.get_dummy_inputs(torch_device)
+
+ # use cfg guidance because some pipelines modify the shape of the latents
+ # outside of the denoising loop
+ inputs["guidance_scale"] = 2.0
+ inputs["callback_on_step_end"] = callback_increase_guidance
+ inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
+ _ = pipe(**inputs)[0]
+
+ # For diffusion forcing pipelines, use the actual callback count
+ # since they run multiple iterations with nested denoising loops
+ expected_guidance_scale = inputs["guidance_scale"] + callback_call_count[0]
+
+ assert pipe.guidance_scale == expected_guidance_scale
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip(
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. Casting the entire pipeline will result in errors"
+ )
+ def test_float16_inference(self):
+ pass
+
+ @unittest.skip(
+ "SkyReelsV2DiffusionForcingVideoToVideoPipeline has to run in mixed precision. Save/Load the entire pipeline in FP16 will result in errors"
+ )
+ def test_save_load_float16(self):
+ pass
diff --git a/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
new file mode 100644
index 0000000000..3ca5862072
--- /dev/null
+++ b/tests/pipelines/skyreels_v2/test_skyreels_v2_image_to_video.py
@@ -0,0 +1,220 @@
+# Copyright 2024 The HuggingFace Team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from PIL import Image
+from transformers import (
+ AutoTokenizer,
+ CLIPImageProcessor,
+ CLIPVisionConfig,
+ CLIPVisionModelWithProjection,
+ T5EncoderModel,
+)
+
+from diffusers import (
+ AutoencoderKLWan,
+ SkyReelsV2ImageToVideoPipeline,
+ SkyReelsV2Transformer3DModel,
+ UniPCMultistepScheduler,
+)
+from diffusers.utils.testing_utils import enable_full_determinism
+
+from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class SkyReelsV2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = SkyReelsV2ImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ vae = AutoencoderKLWan(
+ base_dim=3,
+ z_dim=16,
+ dim_mult=[1, 1, 1, 1],
+ num_res_blocks=1,
+ temperal_downsample=[False, True, True],
+ )
+
+ torch.manual_seed(0)
+ scheduler = UniPCMultistepScheduler(flow_shift=5.0, use_flow_sigmas=True)
+ text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
+
+ torch.manual_seed(0)
+ transformer = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=32,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ image_processor = CLIPImageProcessor(crop_size=32, size=32)
+
+ components = {
+ "transformer": transformer,
+ "vae": vae,
+ "scheduler": scheduler,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "image_encoder": image_encoder,
+ "image_processor": image_processor,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ image_height = 16
+ image_width = 16
+ image = Image.new("RGB", (image_width, image_height))
+ inputs = {
+ "image": image,
+ "prompt": "dance monkey",
+ "negative_prompt": "negative", # TODO
+ "height": image_height,
+ "width": image_width,
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 6.0,
+ "num_frames": 9,
+ "max_sequence_length": 16,
+ "output_type": "pt",
+ }
+ return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ def test_inference_with_last_image(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ torch.manual_seed(0)
+ components["transformer"] = SkyReelsV2Transformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+ torch.manual_seed(0)
+ image_encoder_config = CLIPVisionConfig(
+ hidden_size=4,
+ projection_dim=4,
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ image_size=4,
+ intermediate_size=16,
+ patch_size=1,
+ )
+ components["image_encoder"] = CLIPVisionModelWithProjection(image_encoder_config)
+
+ torch.manual_seed(0)
+ components["image_processor"] = CLIPImageProcessor(crop_size=4, size=4)
+
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ image_height = 16
+ image_width = 16
+ last_image = Image.new("RGB", (image_width, image_height))
+ inputs["last_image"] = last_image
+
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+ expected_video = torch.randn(9, 3, 16, 16)
+ max_diff = np.abs(generated_video - expected_video).max()
+ self.assertLessEqual(max_diff, 1e10)
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
deleted file mode 100644
index 45fc70be23..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_attend_and_excite.py
+++ /dev/null
@@ -1,267 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- StableDiffusionAttendAndExcitePipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- load_numpy,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-torch.backends.cuda.matmul.allow_tf32 = False
-
-
-@skip_mps
-class StableDiffusionAttendAndExcitePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionAttendAndExcitePipeline
- test_attention_slicing = False
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"token_indices"})
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- # Attend and excite requires being able to run a backward pass at
- # inference time. There's no deterministic backward operator for pad
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a cat and a frog",
- "token_indices": [2, 5],
- "generator": generator,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- "max_iter_to_alter": 2,
- "thresholds": {0: 0.7},
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.6391, 0.6290, 0.4860, 0.5134, 0.5550, 0.4577, 0.5033, 0.5023, 0.4538])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice, expected_max_difference=3e-3)
-
- def test_inference(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- self.assertEqual(image.shape, (1, 64, 64, 3))
- expected_slice = np.array(
- [0.63905364, 0.62897307, 0.48599017, 0.5133624, 0.5550048, 0.45769516, 0.50326973, 0.5023139, 0.45384496]
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_sequential_cpu_offload_forward_pass(self):
- super().test_sequential_cpu_offload_forward_pass(expected_max_diff=5e-4)
-
- def test_inference_batch_consistent(self):
- # NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
- self._test_inference_batch_consistent(batch_sizes=[1, 2])
-
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=7e-4)
-
- def test_pt_np_pil_outputs_equivalent(self):
- super().test_pt_np_pil_outputs_equivalent(expected_max_diff=5e-4)
-
- def test_save_load_local(self):
- super().test_save_load_local(expected_max_difference=5e-4)
-
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=4e-4)
-
- def test_karras_schedulers_shape(self):
- super().test_karras_schedulers_shape(num_inference_steps_for_strength_for_iterations=3)
-
- def test_from_pipe_consistent_forward_pass_cpu_offload(self):
- super().test_from_pipe_consistent_forward_pass_cpu_offload(expected_max_diff=5e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
- # Attend and excite requires being able to run a backward pass at
- # inference time. There's no deterministic backward operator for pad
-
- @classmethod
- def setUpClass(cls):
- super().setUpClass()
- torch.use_deterministic_algorithms(False)
-
- @classmethod
- def tearDownClass(cls):
- super().tearDownClass()
- torch.use_deterministic_algorithms(True)
-
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_attend_and_excite_fp16(self):
- generator = torch.manual_seed(51)
-
- pipe = StableDiffusionAttendAndExcitePipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.to(torch_device)
-
- prompt = "a painting of an elephant with glasses"
- token_indices = [5, 7]
-
- image = pipe(
- prompt=prompt,
- token_indices=token_indices,
- guidance_scale=7.5,
- generator=generator,
- num_inference_steps=5,
- max_iter_to_alter=5,
- output_type="np",
- ).images[0]
-
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/attend-and-excite/elephant_glasses.npy"
- )
- max_diff = numpy_cosine_similarity_distance(image.flatten(), expected_image.flatten())
- assert max_diff < 5e-1
diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
deleted file mode 100644
index 9f8870af7b..0000000000
--- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_diffedit.py
+++ /dev/null
@@ -1,452 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMInverseScheduler,
- DDIMScheduler,
- DPMSolverMultistepInverseScheduler,
- DPMSolverMultistepScheduler,
- StableDiffusionDiffEditPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class StableDiffusionDiffEditPipelineFastTests(
- PipelineLatentTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
-):
- pipeline_class = StableDiffusionDiffEditPipeline
- params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"}
- batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"}
- image_params = frozenset(
- []
- ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess
- image_latents_params = frozenset([])
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- inverse_scheduler = DDIMInverseScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_zero=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "inverse_scheduler": inverse_scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- mask = floats_tensor((1, 16, 16), rng=random.Random(seed)).to(device)
- latents = floats_tensor((1, 2, 4, 16, 16), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a dog and a newt",
- "mask_image": mask,
- "image_latents": latents,
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_mask_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "source_prompt": "a cat and a frog",
- "target_prompt": "a dog and a newt",
- "generator": generator,
- "num_inference_steps": 2,
- "num_maps_per_mask": 2,
- "mask_encode_strength": 1.0,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
-
- return inputs
-
- def get_dummy_inversion_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image": image,
- "prompt": "a cat and a frog",
- "generator": generator,
- "num_inference_steps": 2,
- "inpaint_strength": 1.0,
- "guidance_scale": 6.0,
- "decode_latents": True,
- "output_type": "np",
- }
- return inputs
-
- def test_save_load_optional_components(self):
- if not hasattr(self.pipeline_class, "_optional_components"):
- return
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- # set all optional components to None and update pipeline config accordingly
- for optional_component in pipe._optional_components:
- setattr(pipe, optional_component, None)
- pipe.register_modules(**dict.fromkeys(pipe._optional_components))
-
- inputs = self.get_dummy_inputs(torch_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for optional_component in pipe._optional_components:
- self.assertTrue(
- getattr(pipe_loaded, optional_component) is None,
- f"`{optional_component}` did not stay set to None after loading.",
- )
-
- inputs = self.get_dummy_inputs(torch_device)
- output_loaded = pipe_loaded(**inputs)[0]
-
- max_diff = np.abs(output - output_loaded).max()
- self.assertLess(max_diff, 1e-4)
-
- def test_mask(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_mask_inputs(device)
- mask = pipe.generate_mask(**inputs)
- mask_slice = mask[0, -3:, -3:]
-
- self.assertEqual(mask.shape, (1, 16, 16))
- expected_slice = np.array([0] * 9)
- max_diff = np.abs(mask_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
- self.assertEqual(mask[0, -3, -4], 0)
-
- def test_inversion(self):
- device = "cpu"
-
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5160, 0.5115, 0.5060, 0.5456, 0.4704, 0.5060, 0.5019, 0.4405, 0.4726],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=5e-3)
-
- def test_inversion_dpm(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- scheduler_args = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "scaled_linear"}
- components["scheduler"] = DPMSolverMultistepScheduler(**scheduler_args)
- components["inverse_scheduler"] = DPMSolverMultistepInverseScheduler(**scheduler_args)
-
- pipe = self.pipeline_class(**components)
- pipe.to(device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inversion_inputs(device)
- image = pipe.invert(**inputs).images
- image_slice = image[0, -1, -3:, -3:]
-
- self.assertEqual(image.shape, (2, 32, 32, 3))
- expected_slice = np.array(
- [0.5305, 0.4673, 0.5314, 0.5308, 0.4886, 0.5279, 0.5142, 0.4724, 0.4892],
- )
- max_diff = np.abs(image_slice.flatten() - expected_slice).max()
- self.assertLessEqual(max_diff, 1e-3)
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@require_torch_accelerator
-@nightly
-class StableDiffusionDiffEditPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
- raw_image = raw_image.convert("RGB").resize((256, 256))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_full(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1-base", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- pipe.scheduler.clip_sample = True
-
- pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload(device=torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=5,
- ).latents
-
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=5,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((256, 256))
- )
- / 255
- )
-
- assert numpy_cosine_similarity_distance(expected_image.flatten(), image.flatten()) < 2e-1
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionDiffEditPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @classmethod
- def setUpClass(cls):
- raw_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/diffedit/fruit.png"
- )
-
- raw_image = raw_image.convert("RGB").resize((768, 768))
-
- cls.raw_image = raw_image
-
- def test_stable_diffusion_diffedit_dpm(self):
- generator = torch.manual_seed(0)
-
- pipe = StableDiffusionDiffEditPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-1", safety_checker=None, torch_dtype=torch.float16
- )
- pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
- pipe.inverse_scheduler = DPMSolverMultistepInverseScheduler.from_config(pipe.scheduler.config)
- pipe.enable_model_cpu_offload()
- pipe.set_progress_bar_config(disable=None)
-
- source_prompt = "a bowl of fruit"
- target_prompt = "a bowl of pears"
-
- mask_image = pipe.generate_mask(
- image=self.raw_image,
- source_prompt=source_prompt,
- target_prompt=target_prompt,
- generator=generator,
- )
-
- inv_latents = pipe.invert(
- prompt=source_prompt,
- image=self.raw_image,
- inpaint_strength=0.7,
- generator=generator,
- num_inference_steps=25,
- ).latents
-
- image = pipe(
- prompt=target_prompt,
- mask_image=mask_image,
- image_latents=inv_latents,
- generator=generator,
- negative_prompt=source_prompt,
- inpaint_strength=0.7,
- num_inference_steps=25,
- output_type="np",
- ).images[0]
-
- expected_image = (
- np.array(
- load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/diffedit/pears.png"
- ).resize((768, 768))
- )
- / 255
- )
- assert np.abs((expected_image - image).max()) < 5e-1
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
index 577ac4ebdd..2179ec8e22 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3.py
@@ -2,7 +2,6 @@ import gc
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -233,7 +232,6 @@ class StableDiffusion3PipelineFastTests(unittest.TestCase, PipelineTesterMixin):
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3PipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
index f5b5e63a81..7f913cb63d 100644
--- a/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
+++ b/tests/pipelines/stable_diffusion_3/test_pipeline_stable_diffusion_3_img2img.py
@@ -3,7 +3,6 @@ import random
import unittest
import numpy as np
-import pytest
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
@@ -168,7 +167,6 @@ class StableDiffusion3Img2ImgPipelineFastTests(PipelineLatentTesterMixin, unitte
@slow
@require_big_accelerator
-@pytest.mark.big_accelerator
class StableDiffusion3Img2ImgPipelineSlowTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
diff --git a/tests/pipelines/stable_diffusion_gligen/__init__.py b/tests/pipelines/stable_diffusion_gligen/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py b/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
deleted file mode 100644
index 5d56f16803..0000000000
--- a/tests/pipelines/stable_diffusion_gligen/test_stable_diffusion_gligen.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import enable_full_determinism
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenPipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENPipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_gligen_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENPipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip("Test not supported as tokenizer is used for parsing bounding boxes.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/__init__.py b/tests/pipelines/stable_diffusion_gligen_text_image/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
deleted file mode 100644
index 3f092e02dd..0000000000
--- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py
+++ /dev/null
@@ -1,215 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPProcessor,
- CLIPTextConfig,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- StableDiffusionGLIGENTextImagePipeline,
- UNet2DConditionModel,
-)
-from diffusers.pipelines.stable_diffusion import CLIPImageProjection
-from diffusers.utils import load_image
-from diffusers.utils.testing_utils import enable_full_determinism, torch_device
-
-from ..pipeline_params import (
- TEXT_TO_IMAGE_BATCH_PARAMS,
- TEXT_TO_IMAGE_IMAGE_PARAMS,
- TEXT_TO_IMAGE_PARAMS,
-)
-from ..test_pipelines_common import (
- PipelineFromPipeTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class GligenTextImagePipelineFastTests(
- PipelineLatentTesterMixin,
- PipelineKarrasSchedulerTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionGLIGENTextImagePipeline
- params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- supports_dduf = False
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- attention_type="gated-text-image",
- )
- # unet.position_net = PositionNet(32,32)
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- image_encoder_config = CLIPVisionConfig(
- hidden_size=32,
- projection_dim=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- )
- image_encoder = CLIPVisionModelWithProjection(image_encoder_config)
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
-
- image_project = CLIPImageProjection(hidden_size=32)
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": image_encoder,
- "image_project": image_project,
- "processor": processor,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- gligen_images = load_image(
- "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/gligen/livingroom_modern.png"
- )
- inputs = {
- "prompt": "A modern livingroom",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "gligen_phrases": ["a birthday cake"],
- "gligen_images": [gligen_images],
- "gligen_boxes": [[0.2676, 0.6088, 0.4773, 0.7183]],
- "output_type": "np",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.5052, 0.5546, 0.4567, 0.4770, 0.5195, 0.4085, 0.5026, 0.4909, 0.4495])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_stable_diffusion_gligen_text_image_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5069, 0.5561, 0.4577, 0.4792, 0.5203, 0.4089, 0.5039, 0.4919, 0.4499])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_gligen_k_euler_ancestral(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionGLIGENTextImagePipeline(**components)
- sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.425, 0.494, 0.429, 0.469, 0.525, 0.417, 0.533, 0.5, 0.47])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_attention_slicing_forward_pass(self):
- super().test_attention_slicing_forward_pass(expected_max_diff=3e-3)
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=3, expected_max_diff=3e-3)
-
- @unittest.skip(
- "Test not supported because of the use of `text_encoder` in `get_cross_attention_kwargs_with_grounded()`."
- )
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/__init__.py b/tests/pipelines/stable_diffusion_k_diffusion/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py b/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
deleted file mode 100644
index dc7e62078a..0000000000
--- a/tests/pipelines/stable_diffusion_k_diffusion/test_stable_diffusion_k_diffusion.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionKDiffusionPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_1(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.0447, 0.0492, 0.0468, 0.0408, 0.0383, 0.0408, 0.0354, 0.0380, 0.0339])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_2(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=9.0, num_inference_steps=20, output_type="np")
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1237, 0.1320, 0.1438, 0.1359, 0.1390, 0.1132, 0.1277, 0.1175, 0.1112])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-1
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=15,
- output_type="np",
- use_karras_sigmas=True,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array(
- [0.11381689, 0.12112921, 0.1389457, 0.12549606, 0.1244964, 0.10831517, 0.11562866, 0.10867816, 0.10499048]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionKDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=20,
- output_type="np",
- ).images
-
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_ldm3d/__init__.py b/tests/pipelines/stable_diffusion_ldm3d/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py b/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
deleted file mode 100644
index 936e22b470..0000000000
--- a/tests/pipelines/stable_diffusion_ldm3d/test_stable_diffusion_ldm3d.py
+++ /dev/null
@@ -1,326 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- PNDMScheduler,
- StableDiffusionLDM3DPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-
-
-enable_full_determinism()
-
-
-class StableDiffusionLDM3DPipelineFastTests(unittest.TestCase):
- pipeline_class = StableDiffusionLDM3DPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=6,
- out_channels=6,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
-
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- image_slice_rgb = rgb[0, -3:, -3:, -1]
- image_slice_depth = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37338176, 0.70247, 0.74203193, 0.51643604, 0.58256793, 0.60932136, 0.4181095, 0.48355877, 0.46535262]
- )
- expected_slice_depth = np.array([103.46727, 85.812004, 87.849236])
-
- assert np.abs(image_slice_rgb.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(image_slice_depth.flatten() - expected_slice_depth).max() < 1e-2
-
- def test_stable_diffusion_prompt_embeds(self):
- components = self.get_dummy_components()
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["prompt"] = 3 * [inputs["prompt"]]
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_1, depth_slice_1 = output.rgb, output.depth
- rgb_slice_1 = rgb_slice_1[0, -3:, -3:, -1]
- depth_slice_1 = depth_slice_1[0, -3:, -1]
-
- inputs = self.get_dummy_inputs(torch_device)
- prompt = 3 * [inputs.pop("prompt")]
-
- text_inputs = ldm3d_pipe.tokenizer(
- prompt,
- padding="max_length",
- max_length=ldm3d_pipe.tokenizer.model_max_length,
- truncation=True,
- return_tensors="pt",
- )
- text_inputs = text_inputs["input_ids"].to(torch_device)
-
- prompt_embeds = ldm3d_pipe.text_encoder(text_inputs)[0]
-
- inputs["prompt_embeds"] = prompt_embeds
-
- # forward
- output = ldm3d_pipe(**inputs)
- rgb_slice_2, depth_slice_2 = output.rgb, output.depth
- rgb_slice_2 = rgb_slice_2[0, -3:, -3:, -1]
- depth_slice_2 = depth_slice_2[0, -3:, -1]
-
- assert np.abs(rgb_slice_1.flatten() - rgb_slice_2.flatten()).max() < 1e-4
- assert np.abs(depth_slice_1.flatten() - depth_slice_2.flatten()).max() < 1e-4
-
- def test_stable_diffusion_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(skip_prk_steps=True)
- ldm3d_pipe = StableDiffusionLDM3DPipeline(**components)
- ldm3d_pipe = ldm3d_pipe.to(device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = ldm3d_pipe(**inputs, negative_prompt=negative_prompt)
-
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1]
- depth_slice = depth[0, -3:, -1]
-
- assert rgb.shape == (1, 64, 64, 3)
- assert depth.shape == (1, 64, 64)
-
- expected_slice_rgb = np.array(
- [0.37044, 0.71811503, 0.7223251, 0.48603675, 0.5638391, 0.6364948, 0.42833704, 0.4901315, 0.47926217]
- )
- expected_slice_depth = np.array([107.84738, 84.62802, 89.962135])
- assert np.abs(rgb_slice.flatten() - expected_slice_rgb).max() < 1e-2
- assert np.abs(depth_slice.flatten() - expected_slice_depth).max() < 1e-2
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionLDM3DPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d_stable_diffusion(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d")
- ldm3d_pipe = ldm3d_pipe.to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
- rgb_slice = rgb[0, -3:, -3:, -1].flatten()
- depth_slice = rgb[0, -3:, -1].flatten()
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512)
-
- expected_slice_rgb = np.array(
- [0.53805465, 0.56707305, 0.5486515, 0.57012236, 0.5814511, 0.56253487, 0.54843014, 0.55092263, 0.6459706]
- )
- expected_slice_depth = np.array(
- [0.9263781, 0.6678672, 0.5486515, 0.92202145, 0.67831135, 0.56253487, 0.9241694, 0.7551478, 0.6459706]
- )
- assert np.abs(rgb_slice - expected_slice_rgb).max() < 3e-3
- assert np.abs(depth_slice - expected_slice_depth).max() < 3e-3
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
- generator = torch.Generator(device=generator_device).manual_seed(seed)
- latents = np.random.RandomState(seed).standard_normal((1, 4, 64, 64))
- latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
- inputs = {
- "prompt": "a photograph of an astronaut riding a horse",
- "latents": latents,
- "generator": generator,
- "num_inference_steps": 50,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_ldm3d(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.495586
- expected_rgb_std = 0.33795515
- expected_depth_mean = 112.48518
- expected_depth_std = 98.489746
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
-
- def test_ldm3d_v2(self):
- ldm3d_pipe = StableDiffusionLDM3DPipeline.from_pretrained("Intel/ldm3d-4c").to(torch_device)
- ldm3d_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_inputs(torch_device)
- output = ldm3d_pipe(**inputs)
- rgb, depth = output.rgb, output.depth
-
- expected_rgb_mean = 0.4194127
- expected_rgb_std = 0.35375586
- expected_depth_mean = 0.5638502
- expected_depth_std = 0.34686103
-
- assert rgb.shape == (1, 512, 512, 3)
- assert depth.shape == (1, 512, 512, 1)
- assert np.abs(expected_rgb_mean - rgb.mean()) < 1e-3
- assert np.abs(expected_rgb_std - rgb.std()) < 1e-3
- assert np.abs(expected_depth_mean - depth.mean()) < 1e-3
- assert np.abs(expected_depth_std - depth.std()) < 1e-3
diff --git a/tests/pipelines/stable_diffusion_panorama/__init__.py b/tests/pipelines/stable_diffusion_panorama/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py b/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
deleted file mode 100644
index 61f91cae2b..0000000000
--- a/tests/pipelines/stable_diffusion_panorama/test_stable_diffusion_panorama.py
+++ /dev/null
@@ -1,444 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- EulerAncestralDiscreteScheduler,
- LMSDiscreteScheduler,
- PNDMScheduler,
- StableDiffusionPanoramaPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- backend_max_memory_allocated,
- backend_reset_max_memory_allocated,
- backend_reset_peak_memory_stats,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class StableDiffusionPanoramaPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionPanoramaPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- scheduler = DDIMScheduler()
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- # Setting height and width to None to prevent OOMs on CPU.
- "height": None,
- "width": None,
- "num_inference_steps": 1,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6186, 0.5374, 0.4915, 0.4135, 0.4114, 0.4563, 0.5128, 0.4977, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_circular_padding_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs, circular_padding=True).images
- image_slice = image[0, -3:, -3:, -1]
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # override to speed the overall test timing up.
- def test_inference_batch_consistent(self):
- super().test_inference_batch_consistent(batch_sizes=[1, 2])
-
- # override to speed the overall test timing up.
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(batch_size=2, expected_max_diff=5.0e-3)
-
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1e-1)
-
- def test_stable_diffusion_panorama_negative_prompt(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- negative_prompt = "french fries"
- output = sd_pipe(**inputs, negative_prompt=negative_prompt)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6187, 0.5375, 0.4915, 0.4136, 0.4114, 0.4563, 0.5128, 0.4976, 0.4757])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_views_batch_circular_padding(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_euler(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = EulerAncestralDiscreteScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.4024, 0.6510, 0.4901, 0.5378, 0.5813, 0.5622, 0.4795, 0.4467, 0.4952])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- components["scheduler"] = PNDMScheduler(
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
- )
- sd_pipe = StableDiffusionPanoramaPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- image = sd_pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.6391, 0.6291, 0.4861, 0.5134, 0.5552, 0.4578, 0.5032, 0.5023, 0.4539])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPanoramaNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, seed=0):
- generator = torch.manual_seed(seed)
- inputs = {
- "prompt": "a photo of the dolomites",
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 7.5,
- "output_type": "np",
- }
- return inputs
-
- def test_stable_diffusion_panorama_default(self):
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
-
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- 0.36968392,
- 0.27025372,
- 0.32446766,
- 0.28379387,
- 0.36363274,
- 0.30733347,
- 0.27100027,
- 0.27054125,
- 0.25536096,
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_k_lms(self):
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(
- "stabilityai/stable-diffusion-2-base", safety_checker=None
- )
- pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)
- pipe.unet.set_default_attn_processor()
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- image = pipe(**inputs).images
- image_slice = image[0, -3:, -3:, -1].flatten()
- assert image.shape == (1, 512, 2048, 3)
-
- expected_slice = np.array(
- [
- [
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- 0.0,
- ]
- ]
- )
-
- assert np.abs(expected_slice - image_slice).max() < 1e-2
-
- def test_stable_diffusion_panorama_intermediate_state(self):
- number_of_steps = 0
-
- def callback_fn(step: int, timestep: int, latents: torch.Tensor) -> None:
- callback_fn.has_been_called = True
- nonlocal number_of_steps
- number_of_steps += 1
- if step == 1:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18681869,
- 0.33907816,
- 0.5361276,
- 0.14432865,
- -0.02856611,
- -0.73941123,
- 0.23397987,
- 0.47322682,
- -0.37823164,
- ]
- )
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
- elif step == 2:
- latents = latents.detach().cpu().numpy()
- assert latents.shape == (1, 4, 64, 256)
- latents_slice = latents[0, -3:, -3:, -1]
-
- expected_slice = np.array(
- [
- 0.18539645,
- 0.33987248,
- 0.5378559,
- 0.14437142,
- -0.02455261,
- -0.7338317,
- 0.23990755,
- 0.47356272,
- -0.3786505,
- ]
- )
-
- assert np.abs(latents_slice.flatten() - expected_slice).max() < 5e-2
-
- callback_fn.has_been_called = False
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs()
- pipe(**inputs, callback=callback_fn, callback_steps=1)
- assert callback_fn.has_been_called
- assert number_of_steps == 3
-
- def test_stable_diffusion_panorama_pipeline_with_sequential_cpu_offloading(self):
- backend_empty_cache(torch_device)
- backend_reset_max_memory_allocated(torch_device)
- backend_reset_peak_memory_stats(torch_device)
-
- model_ckpt = "stabilityai/stable-diffusion-2-base"
- scheduler = DDIMScheduler.from_pretrained(model_ckpt, subfolder="scheduler")
- pipe = StableDiffusionPanoramaPipeline.from_pretrained(model_ckpt, scheduler=scheduler, safety_checker=None)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing(1)
- pipe.enable_sequential_cpu_offload()
-
- inputs = self.get_inputs()
- _ = pipe(**inputs)
-
- mem_bytes = backend_max_memory_allocated(torch_device)
- # make sure that less than 5.2 GB is allocated
- assert mem_bytes < 5.5 * 10**9
diff --git a/tests/pipelines/stable_diffusion_safe/__init__.py b/tests/pipelines/stable_diffusion_safe/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py b/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
deleted file mode 100644
index 5d81cff3e0..0000000000
--- a/tests/pipelines/stable_diffusion_safe/test_safe_diffusion.py
+++ /dev/null
@@ -1,497 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion_safe import StableDiffusionPipelineSafe as StableDiffusionPipeline
-from diffusers.utils.testing_utils import (
- Expectations,
- backend_empty_cache,
- floats_tensor,
- nightly,
- require_accelerator,
- require_torch_accelerator,
- torch_device,
-)
-
-
-class SafeDiffusionPipelineFastTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- @property
- def dummy_image(self):
- batch_size = 1
- num_channels = 3
- sizes = (32, 32)
-
- image = floats_tensor((batch_size, num_channels) + sizes, rng=random.Random(0)).to(torch_device)
- return image
-
- @property
- def dummy_cond_unet(self):
- torch.manual_seed(0)
- model = UNet2DConditionModel(
- block_out_channels=(32, 64),
- layers_per_block=2,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=32,
- )
- return model
-
- @property
- def dummy_vae(self):
- torch.manual_seed(0)
- model = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- return model
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config)
-
- @property
- def dummy_extractor(self):
- def extract(*args, **kwargs):
- class Out:
- def __init__(self):
- self.pixel_values = torch.ones([0])
-
- def to(self, device):
- self.pixel_values.to(device)
- return self
-
- return Out()
-
- return extract
-
- def test_safe_diffusion_ddim(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
-
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5756, 0.6118, 0.5005, 0.5041, 0.5471, 0.4726, 0.4976, 0.4865, 0.4864])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_pndm(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.Generator(device=device).manual_seed(0)
- output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
- image = output.images
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_tuple = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=6.0,
- num_inference_steps=2,
- output_type="np",
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
- expected_slice = np.array([0.5125, 0.5716, 0.4828, 0.5060, 0.5650, 0.4768, 0.5185, 0.4895, 0.4993])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_no_safety_checker(self):
- pipe = StableDiffusionPipeline.from_pretrained(
- "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
- )
- assert isinstance(pipe, StableDiffusionPipeline)
- assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
- assert pipe.safety_checker is None
-
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- # check that there's no error when saving a pipeline with one of the models being None
- with tempfile.TemporaryDirectory() as tmpdirname:
- pipe.save_pretrained(tmpdirname)
- pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
-
- # sanity check that the pipeline still works
- assert pipe.safety_checker is None
- image = pipe("example prompt", num_inference_steps=2).images[0]
- assert image is not None
-
- @require_accelerator
- def test_stable_diffusion_fp16(self):
- """Test that stable diffusion works with fp16"""
- unet = self.dummy_cond_unet
- scheduler = PNDMScheduler(skip_prk_steps=True)
- vae = self.dummy_vae
- bert = self.dummy_text_encoder
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- # put models in fp16
- unet = unet.half()
- vae = vae.half()
- bert = bert.half()
-
- # make sure here that pndm scheduler skips prk
- sd_pipe = StableDiffusionPipeline(
- unet=unet,
- scheduler=scheduler,
- vae=vae,
- text_encoder=bert,
- tokenizer=tokenizer,
- safety_checker=None,
- feature_extractor=self.dummy_extractor,
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "A painting of a squirrel eating a burger"
- image = sd_pipe([prompt], num_inference_steps=2, output_type="np").images
-
- assert image.shape == (1, 64, 64, 3)
-
-
-@nightly
-@require_torch_accelerator
-class SafeDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_harm_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "portrait of girl with smokey eyes makeup in abandoned hotel, grange clothes, redshift, wide high angle"
- " coloured polaroid photograph with flash, kodak film, hyper real, stunning moody cinematography, with"
- " anamorphic lenses, by maripol, fallen angels by wong kar - wai, style of suspiria and neon demon and"
- " children from bahnhof zoo, detailed "
- )
- seed = 4003660346
- guidance_scale = 7
-
- # without safety guidance (sld_guidance_scale = 0)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
- ("cuda", 7): [0.2278, 0.2231, 0.2249, 0.2333, 0.2303, 0.1885, 0.2273, 0.2144, 0.2176],
- ("cuda", 8): [0.0076, 0.0058, 0.0012, 0, 0.0047, 0.0046, 0, 0, 0],
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- # without safety guidance (strong configuration)
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719]
-
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
- ("cuda", 7): [0.2383, 0.2276, 0.236, 0.2192, 0.2186, 0.2053, 0.1971, 0.1901, 0.1719],
- ("cuda", 8): [0.0443, 0.0439, 0.0381, 0.0336, 0.0408, 0.0345, 0.0405, 0.0338, 0.0293],
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained(
- "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None
- )
- sd_pipe.scheduler = LMSDiscreteScheduler.from_config(sd_pipe.scheduler.config)
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = "padme amidala taking a bath artwork, safe for work, no nudity"
- seed = 2734971755
- guidance_scale = 7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.3244, 0.3355, 0.3260, 0.3123, 0.3246, 0.3426, 0.3109, 0.3471, 0.4001],
- ("cuda", 7): [0.3502, 0.3622, 0.3396, 0.3642, 0.3478, 0.3318, 0.35, 0.3348, 0.3297],
- ("cuda", 8): [0.3605, 0.3684, 0.3712, 0.3624, 0.3675, 0.3726, 0.3494, 0.3748, 0.4044],
- }
- )
- expected_slice = expected_slices.get_expectation()
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): [0.6178, 0.6260, 0.6194, 0.6435, 0.6265, 0.6461, 0.6567, 0.6576, 0.6444],
- ("cuda", 7): [0.5531, 0.5206, 0.4895, 0.5156, 0.5182, 0.4751, 0.4802, 0.4803, 0.4443],
- ("cuda", 8): [0.5892, 0.5959, 0.5914, 0.6123, 0.5982, 0.6141, 0.6180, 0.6262, 0.6171],
- }
- )
-
- print(f"image_slice: {image_slice.flatten()}")
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_nudity_safetychecker_safe_stable_diffusion(self):
- sd_pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- prompt = (
- "the four horsewomen of the apocalypse, painting by tom of finland, gaston bussiere, craig mullins, j. c."
- " leyendecker"
- )
- seed = 1044355234
- guidance_scale = 12
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=0,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
-
- assert image.shape == (1, 512, 512, 3)
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-7
-
- generator = torch.manual_seed(seed)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=guidance_scale,
- num_inference_steps=50,
- output_type="np",
- width=512,
- height=512,
- sld_guidance_scale=2000,
- sld_warmup_steps=7,
- sld_threshold=0.025,
- sld_momentum_scale=0.5,
- sld_mom_beta=0.7,
- )
-
- image = output.images
- image_slice = image[0, -3:, -3:, -1]
- expected_slices = Expectations(
- {
- ("xpu", 3): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
- ("cuda", 7): np.array([0.5818, 0.6285, 0.6835, 0.6019, 0.625, 0.6754, 0.6096, 0.6334, 0.6561]),
- ("cuda", 8): np.array([0.0695, 0.1244, 0.1831, 0.0527, 0.0444, 0.1660, 0.0572, 0.0677, 0.1551]),
- }
- )
- expected_slice = expected_slices.get_expectation()
-
- assert image.shape == (1, 512, 512, 3)
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
diff --git a/tests/pipelines/stable_diffusion_sag/__init__.py b/tests/pipelines/stable_diffusion_sag/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py b/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
deleted file mode 100644
index 1d18403322..0000000000
--- a/tests/pipelines/stable_diffusion_sag/test_stable_diffusion_sag.py
+++ /dev/null
@@ -1,245 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- DEISMultistepScheduler,
- DPMSolverMultistepScheduler,
- EulerDiscreteScheduler,
- StableDiffusionSAGPipeline,
- UNet2DConditionModel,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- IPAdapterTesterMixin,
- PipelineFromPipeTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
-)
-
-
-enable_full_determinism()
-
-
-class StableDiffusionSAGPipelineFastTests(
- IPAdapterTesterMixin,
- PipelineLatentTesterMixin,
- PipelineTesterMixin,
- PipelineFromPipeTesterMixin,
- unittest.TestCase,
-):
- pipeline_class = StableDiffusionSAGPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet2DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=2,
- sample_size=8,
- norm_num_groups=1,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- cross_attention_dim=8,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[4, 8],
- norm_num_groups=1,
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=8,
- num_hidden_layers=2,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- pad_token_id=1,
- vocab_size=1000,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "safety_checker": None,
- "feature_extractor": None,
- "image_encoder": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": ".",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 1.0,
- "sag_scale": 1.0,
- "output_type": "np",
- }
- return inputs
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=3e-3)
-
- @unittest.skip("Not necessary to test here.")
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
- def test_pipeline_different_schedulers(self):
- pipeline = self.pipeline_class(**self.get_dummy_components())
- inputs = self.get_dummy_inputs("cpu")
-
- expected_image_size = (16, 16, 3)
- for scheduler_cls in [DDIMScheduler, DEISMultistepScheduler, DPMSolverMultistepScheduler]:
- pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
- image = pipeline(**inputs).images[0]
-
- shape = image.shape
- assert shape == expected_image_size
-
- pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
-
- with self.assertRaises(ValueError):
- # Karras schedulers are not supported
- image = pipeline(**inputs).images[0]
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@require_torch_accelerator
-class StableDiffusionPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_1(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.1568, 0.1738, 0.1695, 0.1693, 0.1507, 0.1705, 0.1547, 0.1751, 0.1949])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt], generator=generator, guidance_scale=7.5, sag_scale=1.0, num_inference_steps=20, output_type="np"
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.3459, 0.2876, 0.2537, 0.3002, 0.2671, 0.2160, 0.3026, 0.2262, 0.2371])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
-
- def test_stable_diffusion_2_non_square(self):
- sag_pipe = StableDiffusionSAGPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
- sag_pipe = sag_pipe.to(torch_device)
- sag_pipe.set_progress_bar_config(disable=None)
-
- prompt = "."
- generator = torch.manual_seed(0)
- output = sag_pipe(
- [prompt],
- width=768,
- height=512,
- generator=generator,
- guidance_scale=7.5,
- sag_scale=1.0,
- num_inference_steps=20,
- output_type="np",
- )
-
- image = output.images
-
- assert image.shape == (1, 512, 768, 3)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
deleted file mode 100644
index ae131d1d4f..0000000000
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_k_diffusion.py
+++ /dev/null
@@ -1,178 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-
-from diffusers import StableDiffusionXLKDiffusionPipeline
-from diffusers.utils.testing_utils import (
- Expectations,
- backend_empty_cache,
- enable_full_determinism,
- require_torch_accelerator,
- slow,
- torch_device,
-)
-
-
-enable_full_determinism()
-
-
-@slow
-@require_torch_accelerator
-class StableDiffusionXLKPipelineIntegrationTests(unittest.TestCase):
- dtype = torch.float16
-
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_stable_diffusion_xl(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_euler")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=9.0,
- num_inference_steps=2,
- height=512,
- width=512,
- output_type="np",
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slice = np.array([0.5420, 0.5038, 0.2439, 0.5371, 0.4660, 0.1906, 0.5221, 0.4290, 0.2566])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_karras_sigmas(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_2m")
-
- prompt = "A painting of a squirrel eating a burger"
- generator = torch.manual_seed(0)
- output = sd_pipe(
- [prompt],
- generator=generator,
- guidance_scale=7.5,
- num_inference_steps=2,
- output_type="np",
- use_karras_sigmas=True,
- height=512,
- width=512,
- )
-
- image = output.images
-
- image_slice = image[0, -3:, -3:, -1]
-
- assert image.shape == (1, 512, 512, 3)
- expected_slices = Expectations(
- {
- ("xpu", 3): np.array(
- [
- 0.6128,
- 0.6108,
- 0.6109,
- 0.5997,
- 0.5988,
- 0.5948,
- 0.5903,
- 0.597,
- 0.5973,
- ]
- ),
- ("cuda", 7): np.array(
- [
- 0.6418,
- 0.6424,
- 0.6462,
- 0.6271,
- 0.6314,
- 0.6295,
- 0.6249,
- 0.6339,
- 0.6335,
- ]
- ),
- }
- )
-
- expected_slice = expected_slices.get_expectation()
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_stable_diffusion_noise_sampler_seed(self):
- sd_pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
- "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=self.dtype
- )
- sd_pipe = sd_pipe.to(torch_device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- sd_pipe.set_scheduler("sample_dpmpp_sde")
-
- prompt = "A painting of a squirrel eating a burger"
- seed = 0
- images1 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- images2 = sd_pipe(
- [prompt],
- generator=torch.manual_seed(seed),
- noise_sampler_seed=seed,
- guidance_scale=9.0,
- num_inference_steps=2,
- output_type="np",
- height=512,
- width=512,
- ).images
- assert images1.shape == (1, 512, 512, 3)
- assert images2.shape == (1, 512, 512, 3)
- assert np.abs(images1.flatten() - images2.flatten()).max() < 1e-2
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 207cff2a3c..387eb6a614 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -33,9 +33,11 @@ from diffusers import (
)
from diffusers.hooks import apply_group_offloading
from diffusers.hooks.faster_cache import FasterCacheBlockHook, FasterCacheDenoiserHook
+from diffusers.hooks.first_block_cache import FirstBlockCacheConfig
from diffusers.hooks.pyramid_attention_broadcast import PyramidAttentionBroadcastHook
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
+from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -49,6 +51,7 @@ from diffusers.utils.source_code_parsing_utils import ReturnNameVisitor
from diffusers.utils.testing_utils import (
CaptureLogger,
backend_empty_cache,
+ numpy_cosine_similarity_distance,
require_accelerate_version_greater,
require_accelerator,
require_hf_hub_version_greater,
@@ -96,6 +99,20 @@ def check_qkv_fusion_processors_exist(model):
return all(p.startswith("Fused") for p in proc_names)
+def check_qkv_fused_layers_exist(model, layer_names):
+ is_fused_submodules = []
+ for submodule in model.modules():
+ if not isinstance(submodule, AttentionModuleMixin):
+ continue
+ is_fused_attribute_set = submodule.fused_projections
+ is_fused_layer = True
+ for layer in layer_names:
+ is_fused_layer = is_fused_layer and getattr(submodule, layer, None) is not None
+ is_fused = is_fused_attribute_set and is_fused_layer
+ is_fused_submodules.append(is_fused)
+ return all(is_fused_submodules)
+
+
class SDFunctionTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
@@ -1377,7 +1394,6 @@ class PipelineTesterMixin:
for component in pipe_fp16.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
-
pipe_fp16.to(torch_device, torch.float16)
pipe_fp16.set_progress_bar_config(disable=None)
@@ -1385,18 +1401,20 @@ class PipelineTesterMixin:
# Reset generator in case it is used inside dummy inputs
if "generator" in inputs:
inputs["generator"] = self.get_generator(0)
-
output = pipe(**inputs)[0]
fp16_inputs = self.get_dummy_inputs(torch_device)
# Reset generator in case it is used inside dummy inputs
if "generator" in fp16_inputs:
fp16_inputs["generator"] = self.get_generator(0)
-
output_fp16 = pipe_fp16(**fp16_inputs)[0]
- max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
- self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ output_fp16 = output_fp16.cpu()
+
+ max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
+ assert max_diff < expected_max_diff
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
@require_accelerator
@@ -2646,7 +2664,7 @@ class FasterCacheTesterMixin:
self.faster_cache_config.current_timestep_callback = lambda: pipe.current_timestep
pipe = create_pipe()
pipe.transformer.enable_cache(self.faster_cache_config)
- output = run_forward(pipe).flatten().flatten()
+ output = run_forward(pipe).flatten()
image_slice_faster_cache_enabled = np.concatenate((output[:8], output[-8:]))
# Run inference with FasterCache disabled
@@ -2753,6 +2771,55 @@ class FasterCacheTesterMixin:
self.assertTrue(state.cache is None, "Cache should be reset to None.")
+# TODO(aryan, dhruv): the cache tester mixins should probably be rewritten so that more models can be tested out
+# of the box once there is better cache support/implementation
+class FirstBlockCacheTesterMixin:
+ # threshold is intentionally set higher than usual values since we're testing with random unconverged models
+ # that will not satisfy the expected properties of the denoiser for caching to be effective
+ first_block_cache_config = FirstBlockCacheConfig(threshold=0.8)
+
+ def test_first_block_cache_inference(self, expected_atol: float = 0.1):
+ device = "cpu" # ensure determinism for the device-dependent torch.Generator
+
+ def create_pipe():
+ torch.manual_seed(0)
+ num_layers = 2
+ components = self.get_dummy_components(num_layers=num_layers)
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+ return pipe
+
+ def run_forward(pipe):
+ torch.manual_seed(0)
+ inputs = self.get_dummy_inputs(device)
+ inputs["num_inference_steps"] = 4
+ return pipe(**inputs)[0]
+
+ # Run inference without FirstBlockCache
+ pipe = create_pipe()
+ output = run_forward(pipe).flatten()
+ original_image_slice = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache enabled
+ pipe = create_pipe()
+ pipe.transformer.enable_cache(self.first_block_cache_config)
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_enabled = np.concatenate((output[:8], output[-8:]))
+
+ # Run inference with FirstBlockCache disabled
+ pipe.transformer.disable_cache()
+ output = run_forward(pipe).flatten()
+ image_slice_fbc_disabled = np.concatenate((output[:8], output[-8:]))
+
+ assert np.allclose(original_image_slice, image_slice_fbc_enabled, atol=expected_atol), (
+ "FirstBlockCache outputs should not differ much."
+ )
+ assert np.allclose(original_image_slice, image_slice_fbc_disabled, atol=1e-4), (
+ "Outputs from normal inference and after disabling cache should not differ."
+ )
+
+
# Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used.
# This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a
# reference image.
diff --git a/tests/pipelines/text_to_video_synthesis/__init__.py b/tests/pipelines/text_to_video_synthesis/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
deleted file mode 100644
index 445f876985..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoSDPipeline, UNet3DConditionModel
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- load_numpy,
- numpy_cosine_similarity_distance,
- require_torch_accelerator,
- skip_mps,
- slow,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, SDFunctionTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class TextToVideoSDPipelineFastTests(PipelineTesterMixin, SDFunctionTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoSDPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(8, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=4,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=False,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=(8,),
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=4,
- intermediate_size=16,
- layer_norm_eps=1e-05,
- num_attention_heads=2,
- num_hidden_layers=2,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = TextToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
-
- image_slice = frames[0][0][-3:, -3:, -1]
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.8093, 0.2751, 0.6976, 0.5927, 0.4616, 0.4336, 0.5094, 0.5683, 0.4796])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @unittest.skipIf(torch_device != "cuda", reason="Feature isn't heavily used. Test in CUDA environment only.")
- def test_attention_slicing_forward_pass(self):
- self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@slow
-@skip_mps
-@require_torch_accelerator
-class TextToVideoSDPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_two_step_model(self):
- expected_video = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/video_2step.npy"
- )
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- assert numpy_cosine_similarity_distance(expected_video.flatten(), video_frames.flatten()) < 1e-4
-
- def test_two_step_model_with_freeu(self):
- expected_video = []
-
- pipe = TextToVideoSDPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b")
- pipe = pipe.to(torch_device)
-
- prompt = "Spiderman is surfing"
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
- video_frames = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").frames
- video = video_frames[0, 0, -3:, -3:, -1].flatten()
-
- expected_video = [0.3643, 0.3455, 0.3831, 0.3923, 0.2978, 0.3247, 0.3278, 0.3201, 0.3475]
-
- assert np.abs(expected_video - video).mean() < 5e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
deleted file mode 100644
index 8c29b27416..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import torch
-
-from diffusers import DDIMScheduler, TextToVideoZeroPipeline
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- load_pt,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..test_pipelines_common import assert_mean_pixel_difference
-
-
-@nightly
-@require_torch_accelerator
-class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_full_model(self):
- model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
- pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(torch_device)
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- prompt = "A bear is playing a guitar on Times Square"
- result = pipe(prompt=prompt, generator=generator).images
-
- expected_result = load_pt(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/text-to-video/A bear is playing a guitar on Times Square.pt",
- weights_only=False,
- )
-
- assert_mean_pixel_difference(result, expected_result)
diff --git a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py b/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
deleted file mode 100644
index da60435d0d..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_text_to_video_zero_sdxl.py
+++ /dev/null
@@ -1,403 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import inspect
-import tempfile
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import AutoencoderKL, DDIMScheduler, TextToVideoZeroSDXLPipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- nightly,
- require_accelerate_version_greater,
- require_torch_accelerator,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-def to_np(tensor):
- if isinstance(tensor, torch.Tensor):
- tensor = tensor.detach().cpu().numpy()
-
- return tensor
-
-
-class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
- pipeline_class = TextToVideoZeroSDXLPipeline
- params = TEXT_TO_IMAGE_PARAMS
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
- generator_device = "cpu"
-
- def get_dummy_components(self, seed=0):
- torch.manual_seed(seed)
- unet = UNet2DConditionModel(
- block_out_channels=(2, 4),
- layers_per_block=2,
- sample_size=2,
- norm_num_groups=2,
- in_channels=4,
- out_channels=4,
- down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
- up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
- # SD2-specific config below
- attention_head_dim=(2, 4),
- use_linear_projection=True,
- addition_embed_type="text_time",
- addition_time_embed_dim=8,
- transformer_layers_per_block=(1, 2),
- projection_class_embeddings_input_dim=80, # 6 * 8 + 32
- cross_attention_dim=64,
- )
- scheduler = DDIMScheduler(
- num_train_timesteps=1000,
- beta_start=0.0001,
- beta_end=0.02,
- beta_schedule="linear",
- trained_betas=None,
- clip_sample=True,
- set_alpha_to_one=True,
- steps_offset=0,
- prediction_type="epsilon",
- thresholding=False,
- dynamic_thresholding_ratio=0.995,
- clip_sample_range=1.0,
- sample_max_value=1.0,
- timestep_spacing="leading",
- rescale_betas_zero_snr=False,
- )
- torch.manual_seed(seed)
- vae = AutoencoderKL(
- block_out_channels=[32, 64],
- in_channels=3,
- out_channels=3,
- down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
- up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=128,
- )
- torch.manual_seed(seed)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- # SD2-specific config below
- hidden_act="gelu",
- projection_dim=32,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
- tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_encoder_2": text_encoder_2,
- "tokenizer_2": tokenizer_2,
- "image_encoder": None,
- "feature_extractor": None,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A panda dancing in Antarctica",
- "generator": generator,
- "num_inference_steps": 5,
- "t0": 1,
- "t1": 3,
- "height": 64,
- "width": 64,
- "video_length": 3,
- "output_type": "np",
- }
- return inputs
-
- def get_generator(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- return generator
-
- def test_text_to_video_zero_sdxl(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- result = pipe(**inputs).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array(
- [0.6008109, 0.73051643, 0.51778656, 0.55817354, 0.45222935, 0.45998418, 0.57017255, 0.54874814, 0.47078788]
- )
- expected_slice2 = np.array(
- [0.6011751, 0.47420046, 0.41660714, 0.6472957, 0.41261768, 0.5438129, 0.7401535, 0.6756011, 0.53652245]
- )
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_attention_slicing_forward_pass(self):
- pass
-
- def test_cfg(self):
- sig = inspect.signature(self.pipeline_class.__call__)
- if "guidance_scale" not in sig.parameters:
- return
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
-
- inputs["guidance_scale"] = 1.0
- out_no_cfg = pipe(**inputs)[0]
-
- inputs["guidance_scale"] = 7.5
- out_cfg = pipe(**inputs)[0]
-
- assert out_cfg.shape == out_no_cfg.shape
-
- def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(self.generator_device))[0]
- output_tuple = pipe(**self.get_dummy_inputs(self.generator_device), return_dict=False)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_tuple)).max()
- self.assertLess(max_diff, expected_max_difference)
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_torch_accelerator
- def test_float16_inference(self, expected_max_diff=5e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- components = self.get_dummy_components()
- pipe_fp16 = self.pipeline_class(**components)
- pipe_fp16.to(torch_device, torch.float16)
- pipe_fp16.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- # # Reset generator in case it is used inside dummy inputs
- if "generator" in inputs:
- inputs["generator"] = self.get_generator(self.generator_device)
-
- output = pipe(**inputs)[0]
-
- fp16_inputs = self.get_dummy_inputs(self.generator_device)
- # Reset generator in case it is used inside dummy inputs
- if "generator" in fp16_inputs:
- fp16_inputs["generator"] = self.get_generator(self.generator_device)
-
- output_fp16 = pipe_fp16(**fp16_inputs)[0]
-
- max_diff = np.abs(to_np(output) - to_np(output_fp16)).max()
- self.assertLess(max_diff, expected_max_diff, "The outputs of the fp16 and fp32 pipelines are too different.")
-
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_inference_batch_single_identical(self):
- pass
-
- @require_torch_accelerator
- @require_accelerate_version_greater("0.17.0")
- def test_model_cpu_offload_forward_pass(self, expected_max_diff=2e-4):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_without_offload = pipe(**inputs)[0]
-
- pipe.enable_model_cpu_offload(device=torch_device)
- inputs = self.get_dummy_inputs(self.generator_device)
- output_with_offload = pipe(**inputs)[0]
-
- max_diff = np.abs(to_np(output_with_offload) - to_np(output_without_offload)).max()
- self.assertLess(max_diff, expected_max_diff, "CPU offloading should not affect the inference results")
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_pipeline_call_signature(self):
- pass
-
- @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
- @require_torch_accelerator
- def test_save_load_float16(self, expected_max_diff=1e-2):
- components = self.get_dummy_components()
- for name, module in components.items():
- if hasattr(module, "half"):
- components[name] = module.to(torch_device).half()
-
- pipe = self.pipeline_class(**components)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output = pipe(**inputs)[0]
-
- with tempfile.TemporaryDirectory() as tmpdir:
- pipe.save_pretrained(tmpdir)
- pipe_loaded = self.pipeline_class.from_pretrained(tmpdir, torch_dtype=torch.float16)
- pipe_loaded.to(torch_device)
- pipe_loaded.set_progress_bar_config(disable=None)
-
- for name, component in pipe_loaded.components.items():
- if hasattr(component, "dtype"):
- self.assertTrue(
- component.dtype == torch.float16,
- f"`{name}.dtype` switched from `float16` to {component.dtype} after loading.",
- )
-
- inputs = self.get_dummy_inputs(self.generator_device)
- output_loaded = pipe_loaded(**inputs)[0]
- max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
- self.assertLess(
- max_diff, expected_max_diff, "The output of the fp16 pipeline changed after saving and loading."
- )
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_local(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_save_load_optional_components(self):
- pass
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_sequential_cpu_offload_forward_pass(self):
- pass
-
- @require_torch_accelerator
- def test_to_device(self):
- components = self.get_dummy_components()
- pipe = self.pipeline_class(**components)
- pipe.set_progress_bar_config(disable=None)
-
- pipe.to("cpu")
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == "cpu" for device in model_devices))
-
- output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(output_cpu).sum() == 0)
-
- pipe.to(torch_device)
- model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
- self.assertTrue(all(device == torch_device for device in model_devices))
-
- output_device = pipe(**self.get_dummy_inputs("cpu"))[0] # generator set to cpu
- self.assertTrue(np.isnan(to_np(output_device)).sum() == 0)
-
- @unittest.skip(
- reason="Cannot call `set_default_attn_processor` as this pipeline uses a specific attention processor."
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- pass
-
-
-@nightly
-@require_torch_accelerator
-class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_full_model(self):
- model_id = "stabilityai/stable-diffusion-xl-base-1.0"
- pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
- model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
- )
- pipe.enable_model_cpu_offload()
- pipe.enable_vae_slicing()
-
- pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
- generator = torch.Generator(device="cpu").manual_seed(0)
-
- prompt = "A panda dancing in Antarctica"
- result = pipe(prompt=prompt, generator=generator).images
-
- first_frame_slice = result[0, -3:, -3:, -1]
- last_frame_slice = result[-1, -3:, -3:, 0]
-
- expected_slice1 = np.array([0.57, 0.57, 0.57, 0.57, 0.57, 0.56, 0.55, 0.56, 0.56])
- expected_slice2 = np.array([0.54, 0.53, 0.53, 0.53, 0.53, 0.52, 0.53, 0.53, 0.53])
-
- assert np.abs(first_frame_slice.flatten() - expected_slice1).max() < 1e-2
- assert np.abs(last_frame_slice.flatten() - expected_slice2).max() < 1e-2
diff --git a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py b/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
deleted file mode 100644
index 2efef3d640..0000000000
--- a/tests/pipelines/text_to_video_synthesis/test_video_to_video.py
+++ /dev/null
@@ -1,229 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import (
- AutoencoderKL,
- DDIMScheduler,
- UNet3DConditionModel,
- VideoToVideoSDPipeline,
-)
-from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import (
- enable_full_determinism,
- floats_tensor,
- is_flaky,
- nightly,
- numpy_cosine_similarity_distance,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import (
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-@skip_mps
-class VideoToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = VideoToVideoSDPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"video"}) - {"image", "width", "height"}
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS.union({"video"}) - {"image"}
- required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
- test_attention_slicing = False
-
- # No `output_type`.
- required_optional_params = frozenset(
- [
- "num_inference_steps",
- "generator",
- "latents",
- "return_dict",
- "callback",
- "callback_steps",
- ]
- )
-
- def get_dummy_components(self):
- torch.manual_seed(0)
- unet = UNet3DConditionModel(
- block_out_channels=(4, 8),
- layers_per_block=1,
- sample_size=32,
- in_channels=4,
- out_channels=4,
- down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
- up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
- cross_attention_dim=32,
- attention_head_dim=4,
- norm_num_groups=2,
- )
- scheduler = DDIMScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- clip_sample=True,
- set_alpha_to_one=False,
- )
- torch.manual_seed(0)
- vae = AutoencoderKL(
- block_out_channels=[
- 8,
- ],
- in_channels=3,
- out_channels=3,
- down_block_types=[
- "DownEncoderBlock2D",
- ],
- up_block_types=["UpDecoderBlock2D"],
- latent_channels=4,
- sample_size=32,
- norm_num_groups=2,
- )
- torch.manual_seed(0)
- text_encoder_config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=32,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- hidden_act="gelu",
- projection_dim=512,
- )
- text_encoder = CLIPTextModel(text_encoder_config)
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- components = {
- "unet": unet,
- "scheduler": scheduler,
- "vae": vae,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- }
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- # 3 frames
- video = floats_tensor((1, 3, 3, 32, 32), rng=random.Random(seed)).to(device)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "A painting of a squirrel eating a burger",
- "video": video,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "pt",
- }
- return inputs
-
- def test_text_to_video_default_case(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- sd_pipe = VideoToVideoSDPipeline(**components)
- sd_pipe = sd_pipe.to(device)
- sd_pipe.set_progress_bar_config(disable=None)
-
- inputs = self.get_dummy_inputs(device)
- inputs["output_type"] = "np"
- frames = sd_pipe(**inputs).frames
- image_slice = frames[0][0][-3:, -3:, -1]
-
- assert frames[0][0].shape == (32, 32, 3)
- expected_slice = np.array([0.6391, 0.5350, 0.5202, 0.5521, 0.5453, 0.5393, 0.6652, 0.5270, 0.5185])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
- @is_flaky()
- def test_save_load_optional_components(self):
- super().test_save_load_optional_components(expected_max_difference=0.001)
-
- @is_flaky()
- def test_dict_tuple_outputs_equivalent(self):
- super().test_dict_tuple_outputs_equivalent()
-
- @is_flaky()
- def test_save_load_local(self):
- super().test_save_load_local()
-
- @unittest.skipIf(
- torch_device != "cuda" or not is_xformers_available(),
- reason="XFormers attention is only available with CUDA and `xformers` installed",
- )
- def test_xformers_attention_forwardGenerator_pass(self):
- self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=5e-3)
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_consistent(self):
- pass
-
- # (todo): sayakpaul
- @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
- def test_inference_batch_single_identical(self):
- pass
-
- @unittest.skip(reason="`num_images_per_prompt` argument is not supported for this pipeline.")
- def test_num_images_per_prompt(self):
- pass
-
- def test_encode_prompt_works_in_isolation(self):
- extra_required_param_value_dict = {
- "device": torch.device(torch_device).type,
- "num_images_per_prompt": 1,
- "do_classifier_free_guidance": self.get_dummy_inputs(device=torch_device).get("guidance_scale", 1.0) > 1.0,
- }
- return super().test_encode_prompt_works_in_isolation(extra_required_param_value_dict)
-
-
-@nightly
-@skip_mps
-class VideoToVideoSDPipelineSlowTests(unittest.TestCase):
- def test_two_step_model(self):
- pipe = VideoToVideoSDPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=torch.float16)
- pipe.enable_model_cpu_offload()
-
- # 10 frames
- generator = torch.Generator(device="cpu").manual_seed(0)
- video = torch.randn((1, 10, 3, 320, 576), generator=generator)
-
- prompt = "Spiderman is surfing"
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- video_frames = pipe(prompt, video=video, generator=generator, num_inference_steps=3, output_type="np").frames
-
- expected_array = np.array(
- [0.17114258, 0.13720703, 0.08886719, 0.14819336, 0.1730957, 0.24584961, 0.22021484, 0.35180664, 0.2607422]
- )
- output_array = video_frames[0, 0, :3, :3, 0].flatten()
- assert numpy_cosine_similarity_distance(expected_array, output_array) < 1e-3
diff --git a/tests/pipelines/unclip/__init__.py b/tests/pipelines/unclip/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/unclip/test_unclip.py b/tests/pipelines/unclip/test_unclip.py
deleted file mode 100644
index 4a970a4f6f..0000000000
--- a/tests/pipelines/unclip/test_unclip.py
+++ /dev/null
@@ -1,523 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
-
-from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- backend_max_memory_allocated,
- backend_reset_max_memory_allocated,
- backend_reset_peak_memory_stats,
- enable_full_determinism,
- load_numpy,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPPipeline
- params = TEXT_TO_IMAGE_PARAMS - {
- "negative_prompt",
- "height",
- "width",
- "negative_prompt_embeds",
- "guidance_scale",
- "prompt_embeds",
- "cross_attention_kwargs",
- }
- batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
- required_optional_params = [
- "generator",
- "return_dict",
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "num_attention_heads": 2,
- "attention_head_dim": 12,
- "embedding_dim": self.text_embedder_hidden_size,
- "num_layers": 1,
- }
-
- model = PriorTransformer(**model_kwargs)
- return model
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
- # Out channels is double in channels because predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
- # seeded differently to get different unet than `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- prior_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="sample",
- num_train_timesteps=1000,
- clip_sample_range=5.0,
- )
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- components = {
- "prior": prior,
- "decoder": decoder,
- "text_proj": text_proj,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "prior_scheduler": prior_scheduler,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_num_inference_steps": 2,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_unclip(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(
- **self.get_dummy_inputs(device),
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9988,
- 0.0028,
- 0.9997,
- 0.9984,
- 0.9965,
- 0.0029,
- 0.9986,
- 0.0025,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_text_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- prior = components["prior"]
- decoder = components["decoder"]
- super_res_first = components["super_res_first"]
- tokenizer = components["tokenizer"]
- text_encoder = components["text_encoder"]
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = prior.dtype
- batch_size = 1
-
- shape = (batch_size, prior.config.embedding_dim)
- prior_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
- shape = (batch_size, decoder.config.in_channels, decoder.config.sample_size, decoder.config.sample_size)
- generator = torch.Generator(device=device).manual_seed(0)
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- super_res_first.config.in_channels // 2,
- super_res_first.config.sample_size,
- super_res_first.config.sample_size,
- )
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipe.set_progress_bar_config(disable=None)
-
- prompt = "this is a prompt example"
-
- generator = torch.Generator(device=device).manual_seed(0)
- output = pipe(
- [prompt],
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- output_type="np",
- )
- image = output.images
-
- text_inputs = tokenizer(
- prompt,
- padding="max_length",
- max_length=tokenizer.model_max_length,
- return_tensors="pt",
- )
- text_model_output = text_encoder(text_inputs.input_ids)
- text_attention_mask = text_inputs.attention_mask
-
- generator = torch.Generator(device=device).manual_seed(0)
- image_from_text = pipe(
- generator=generator,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- prior_latents=prior_latents,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- text_model_output=text_model_output,
- text_attention_mask=text_attention_mask,
- output_type="np",
- )[0]
-
- # make sure passing text embeddings manually is identical
- assert np.abs(image - image_from_text).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
- # because UnCLIP GPU undeterminism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
- self._test_attention_slicing_forward_pass(test_max_difference=test_max_difference, expected_max_diff=0.01)
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
- # because UnCLIP undeterminism requires a looser check.
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=9.8e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "prior_num_inference_steps",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=5e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @unittest.skip("UnCLIP produces very large differences in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-class UnCLIPPipelineCPUIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_karlo_cpu_fp32(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_cpu.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.manual_seed(0)
- output = pipeline(
- "horse",
- num_images_per_prompt=1,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
- assert np.abs(expected_image - image).max() < 1e-1
-
-
-@nightly
-@require_torch_accelerator
-class UnCLIPPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_karlo(self):
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_horse_fp16.npy"
- )
-
- pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- "horse",
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image)
-
- def test_unclip_pipeline_with_sequential_cpu_offloading(self):
- backend_empty_cache(torch_device)
- backend_reset_max_memory_allocated(torch_device)
- backend_reset_peak_memory_stats(torch_device)
-
- pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
- pipe.enable_sequential_cpu_offload()
-
- _ = pipe(
- "horse",
- num_images_per_prompt=1,
- prior_num_inference_steps=2,
- decoder_num_inference_steps=2,
- super_res_num_inference_steps=2,
- output_type="np",
- )
-
- mem_bytes = backend_max_memory_allocated(torch_device)
- # make sure that less than 7 GB is allocated
- assert mem_bytes < 7 * 10**9
diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py
deleted file mode 100644
index 15733513a5..0000000000
--- a/tests/pipelines/unclip/test_unclip_image_variation.py
+++ /dev/null
@@ -1,540 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from transformers import (
- CLIPImageProcessor,
- CLIPTextConfig,
- CLIPTextModelWithProjection,
- CLIPTokenizer,
- CLIPVisionConfig,
- CLIPVisionModelWithProjection,
-)
-
-from diffusers import (
- DiffusionPipeline,
- UnCLIPImageVariationPipeline,
- UnCLIPScheduler,
- UNet2DConditionModel,
- UNet2DModel,
-)
-from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- load_numpy,
- nightly,
- require_torch_accelerator,
- skip_mps,
- torch_device,
-)
-
-from ..pipeline_params import IMAGE_VARIATION_BATCH_PARAMS, IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
-
-
-enable_full_determinism()
-
-
-class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = UnCLIPImageVariationPipeline
- params = IMAGE_VARIATION_PARAMS - {"height", "width", "guidance_scale"}
- batch_params = IMAGE_VARIATION_BATCH_PARAMS
-
- required_optional_params = [
- "generator",
- "return_dict",
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- test_xformers_attention = False
- supports_dduf = False
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def cross_attention_dim(self):
- return 100
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModelWithProjection(config)
-
- @property
- def dummy_image_encoder(self):
- torch.manual_seed(0)
- config = CLIPVisionConfig(
- hidden_size=self.text_embedder_hidden_size,
- projection_dim=self.text_embedder_hidden_size,
- num_hidden_layers=5,
- num_attention_heads=4,
- image_size=32,
- intermediate_size=37,
- patch_size=1,
- )
- return CLIPVisionModelWithProjection(config)
-
- @property
- def dummy_text_proj(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "clip_embeddings_dim": self.text_embedder_hidden_size,
- "time_embed_dim": self.time_embed_dim,
- "cross_attention_dim": self.cross_attention_dim,
- }
-
- model = UnCLIPTextProjModel(**model_kwargs)
- return model
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "sample_size": 32,
- # RGB in channels
- "in_channels": 3,
- # Out channels is double in channels because predicts mean and variance
- "out_channels": 6,
- "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
- "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
- "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "layers_per_block": 1,
- "cross_attention_dim": self.cross_attention_dim,
- "attention_head_dim": 4,
- "resnet_time_scale_shift": "scale_shift",
- "class_embed_type": "identity",
- }
-
- model = UNet2DConditionModel(**model_kwargs)
- return model
-
- @property
- def dummy_super_res_kwargs(self):
- return {
- "sample_size": 64,
- "layers_per_block": 1,
- "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
- "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
- "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
- "in_channels": 6,
- "out_channels": 3,
- }
-
- @property
- def dummy_super_res_first(self):
- torch.manual_seed(0)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- @property
- def dummy_super_res_last(self):
- # seeded differently to get different unet than `self.dummy_super_res_first`
- torch.manual_seed(1)
-
- model = UNet2DModel(**self.dummy_super_res_kwargs)
- return model
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_proj = self.dummy_text_proj
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- super_res_first = self.dummy_super_res_first
- super_res_last = self.dummy_super_res_last
-
- decoder_scheduler = UnCLIPScheduler(
- variance_type="learned_range",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- super_res_scheduler = UnCLIPScheduler(
- variance_type="fixed_small_log",
- prediction_type="epsilon",
- num_train_timesteps=1000,
- )
-
- feature_extractor = CLIPImageProcessor(crop_size=32, size=32)
-
- image_encoder = self.dummy_image_encoder
-
- return {
- "decoder": decoder,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "text_proj": text_proj,
- "feature_extractor": feature_extractor,
- "image_encoder": image_encoder,
- "super_res_first": super_res_first,
- "super_res_last": super_res_last,
- "decoder_scheduler": decoder_scheduler,
- "super_res_scheduler": super_res_scheduler,
- }
-
- def get_dummy_inputs(self, device, seed=0, pil_image=True):
- input_image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- if pil_image:
- input_image = input_image * 0.5 + 0.5
- input_image = input_image.clamp(0, 1)
- input_image = input_image.cpu().permute(0, 2, 3, 1).float().numpy()
- input_image = DiffusionPipeline.numpy_to_pil(input_image)[0]
-
- return {
- "image": input_image,
- "generator": generator,
- "decoder_num_inference_steps": 2,
- "super_res_num_inference_steps": 2,
- "output_type": "np",
- }
-
- def test_unclip_image_variation_input_tensor(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.0002,
- 0.9997,
- 0.9997,
- 0.9969,
- 0.0023,
- 0.9997,
- 0.9969,
- 0.9970,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_image(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.9997, 0.0003, 0.9997, 0.9997, 0.9970, 0.0024, 0.9997, 0.9971, 0.9971])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_image_variation_input_list_images(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- pipeline_inputs["image"] = [
- pipeline_inputs["image"],
- pipeline_inputs["image"],
- ]
-
- output = pipe(**pipeline_inputs)
- image = output.images
-
- tuple_pipeline_inputs = self.get_dummy_inputs(device, pil_image=True)
- tuple_pipeline_inputs["image"] = [
- tuple_pipeline_inputs["image"],
- tuple_pipeline_inputs["image"],
- ]
-
- image_from_tuple = pipe(
- **tuple_pipeline_inputs,
- return_dict=False,
- )[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (2, 64, 64, 3)
-
- expected_slice = np.array(
- [
- 0.9997,
- 0.9989,
- 0.0008,
- 0.0021,
- 0.9960,
- 0.0018,
- 0.0014,
- 0.0002,
- 0.9933,
- ]
- )
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- def test_unclip_passed_image_embed(self):
- device = torch.device("cpu")
-
- class DummyScheduler:
- init_noise_sigma = 1
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device=device).manual_seed(0)
- dtype = pipe.decoder.dtype
- batch_size = 1
-
- shape = (
- batch_size,
- pipe.decoder.config.in_channels,
- pipe.decoder.config.sample_size,
- pipe.decoder.config.sample_size,
- )
- decoder_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- shape = (
- batch_size,
- pipe.super_res_first.config.in_channels // 2,
- pipe.super_res_first.config.sample_size,
- pipe.super_res_first.config.sample_size,
- )
- generator = torch.Generator(device=device).manual_seed(0)
- super_res_latents = pipe.prepare_latents(
- shape, dtype=dtype, device=device, generator=generator, latents=None, scheduler=DummyScheduler()
- )
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
-
- img_out_1 = pipe(
- **pipeline_inputs, decoder_latents=decoder_latents, super_res_latents=super_res_latents
- ).images
-
- pipeline_inputs = self.get_dummy_inputs(device, pil_image=False)
- # Don't pass image, instead pass embedding
- image = pipeline_inputs.pop("image")
- image_embeddings = pipe.image_encoder(image).image_embeds
-
- img_out_2 = pipe(
- **pipeline_inputs,
- decoder_latents=decoder_latents,
- super_res_latents=super_res_latents,
- image_embeddings=image_embeddings,
- ).images
-
- # make sure passing text embeddings manually is identical
- assert np.abs(img_out_1 - img_out_2).max() < 1e-4
-
- # Overriding PipelineTesterMixin::test_attention_slicing_forward_pass
- # because UnCLIP GPU undeterminism requires a looser check.
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
-
- # Check is relaxed because there is not a torch 2.0 sliced attention added kv processor
- expected_max_diff = 1e-2
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference, expected_max_diff=expected_max_diff
- )
-
- # Overriding PipelineTesterMixin::test_inference_batch_single_identical
- # because UnCLIP undeterminism requires a looser check.
- @unittest.skip("UnCLIP produces very large differences. Test is not useful.")
- @skip_mps
- def test_inference_batch_single_identical(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
- self._test_inference_batch_single_identical(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
- )
-
- def test_inference_batch_consistent(self):
- additional_params_copy_to_batched_inputs = [
- "decoder_num_inference_steps",
- "super_res_num_inference_steps",
- ]
-
- if torch_device == "mps":
- # TODO: MPS errors with larger batch sizes
- batch_sizes = [2, 3]
- self._test_inference_batch_consistent(
- batch_sizes=batch_sizes,
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs,
- )
- else:
- self._test_inference_batch_consistent(
- additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs
- )
-
- @skip_mps
- def test_dict_tuple_outputs_equivalent(self):
- return super().test_dict_tuple_outputs_equivalent()
-
- @unittest.skip("UnCLIP produces very large difference. Test is not useful.")
- @skip_mps
- def test_save_load_local(self):
- return super().test_save_load_local(expected_max_difference=4e-3)
-
- @skip_mps
- def test_save_load_optional_components(self):
- return super().test_save_load_optional_components()
-
- @unittest.skip("UnCLIP produces very large difference in fp16 vs fp32. Test is not useful.")
- def test_float16_inference(self):
- super().test_float16_inference(expected_max_diff=1.0)
-
-
-@nightly
-@require_torch_accelerator
-class UnCLIPImageVariationPipelineIntegrationTests(unittest.TestCase):
- def setUp(self):
- # clean up the VRAM before each test
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- # clean up the VRAM after each test
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def test_unclip_image_variation_karlo(self):
- input_image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unclip/cat.png"
- )
- expected_image = load_numpy(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
- "/unclip/karlo_v1_alpha_cat_variation_fp16.npy"
- )
-
- pipeline = UnCLIPImageVariationPipeline.from_pretrained(
- "kakaobrain/karlo-v1-alpha-image-variations", torch_dtype=torch.float16
- )
- pipeline = pipeline.to(torch_device)
- pipeline.set_progress_bar_config(disable=None)
-
- generator = torch.Generator(device="cpu").manual_seed(0)
- output = pipeline(
- input_image,
- generator=generator,
- output_type="np",
- )
-
- image = output.images[0]
-
- assert image.shape == (256, 256, 3)
-
- assert_mean_pixel_difference(image, expected_image, 15)
diff --git a/tests/pipelines/unidiffuser/__init__.py b/tests/pipelines/unidiffuser/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py
deleted file mode 100644
index dccb1a8500..0000000000
--- a/tests/pipelines/unidiffuser/test_unidiffuser.py
+++ /dev/null
@@ -1,764 +0,0 @@
-import gc
-import random
-import unittest
-
-import numpy as np
-import torch
-from PIL import Image
-from transformers import (
- CLIPImageProcessor,
- CLIPTextModel,
- CLIPTokenizer,
- CLIPVisionModelWithProjection,
- GPT2Tokenizer,
-)
-
-from diffusers import (
- AutoencoderKL,
- DPMSolverMultistepScheduler,
- UniDiffuserModel,
- UniDiffuserPipeline,
- UniDiffuserTextDecoder,
-)
-from diffusers.utils.testing_utils import (
- backend_empty_cache,
- enable_full_determinism,
- floats_tensor,
- load_image,
- nightly,
- require_torch_accelerator,
- torch_device,
-)
-from diffusers.utils.torch_utils import randn_tensor
-
-from ..pipeline_params import (
- IMAGE_TO_IMAGE_IMAGE_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
- TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
-)
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class UniDiffuserPipelineFastTests(
- PipelineTesterMixin, PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
-):
- pipeline_class = UniDiffuserPipeline
- params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
- batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
- image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
- # vae_latents, not latents, is the argument that corresponds to VAE latent inputs
- image_latents_params = frozenset(["vae_latents"])
-
- supports_dduf = False
-
- def get_dummy_components(self):
- unet = UniDiffuserModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="unet",
- )
-
- scheduler = DPMSolverMultistepScheduler(
- beta_start=0.00085,
- beta_end=0.012,
- beta_schedule="scaled_linear",
- solver_order=3,
- )
-
- vae = AutoencoderKL.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="vae",
- )
-
- text_encoder = CLIPTextModel.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_encoder",
- )
- clip_tokenizer = CLIPTokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="clip_tokenizer",
- )
-
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="image_encoder",
- )
- # From the Stable Diffusion Image Variation pipeline tests
- clip_image_processor = CLIPImageProcessor(crop_size=32, size=32)
- # image_processor = CLIPImageProcessor.from_pretrained("hf-internal-testing/tiny-random-clip")
-
- text_tokenizer = GPT2Tokenizer.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_tokenizer",
- )
- text_decoder = UniDiffuserTextDecoder.from_pretrained(
- "hf-internal-testing/unidiffuser-diffusers-test",
- subfolder="text_decoder",
- )
-
- components = {
- "vae": vae,
- "text_encoder": text_encoder,
- "image_encoder": image_encoder,
- "clip_image_processor": clip_image_processor,
- "clip_tokenizer": clip_tokenizer,
- "text_decoder": text_decoder,
- "text_tokenizer": text_tokenizer,
- "unet": unet,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- image = image.cpu().permute(0, 2, 3, 1)[0]
- image = Image.fromarray(np.uint8(image)).convert("RGB")
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- }
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- generator = torch.Generator(device=device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 32), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 16, 16), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 32), generator=generator, device=device, dtype=torch.float32)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def get_dummy_inputs_with_latents(self, device, seed=0):
- # image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
- # image = image.cpu().permute(0, 2, 3, 1)[0]
- # image = Image.fromarray(np.uint8(image)).convert("RGB")
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg",
- )
- image = image.resize((32, 32))
- latents = self.get_fixed_latents(device, seed=seed)
-
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
-
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 2,
- "guidance_scale": 6.0,
- "output_type": "np",
- "prompt_latents": latents.get("prompt_latents"),
- "vae_latents": latents.get("vae_latents"),
- "clip_latents": latents.get("clip_latents"),
- }
- return inputs
-
- def test_dict_tuple_outputs_equivalent(self):
- expected_slice = None
- if torch_device == "cpu":
- expected_slice = np.array([0.7489, 0.3722, 0.4475, 0.5630, 0.5923, 0.4992, 0.3936, 0.5844, 0.4975])
- super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
-
- def test_unidiffuser_default_joint_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_no_cfg_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- # Set guidance scale to 1.0 to turn off CFG
- inputs["guidance_scale"] = 1.0
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_image_0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img'
- unidiffuser_pipe.set_image_mode()
- assert unidiffuser_pipe.mode == "img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") text generation.
- del inputs["prompt"]
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5760, 0.6270, 0.6571, 0.4966, 0.4638, 0.5663, 0.5254, 0.5068, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img'
- unidiffuser_pipe.set_text_mode()
- assert unidiffuser_pipe.mode == "text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete prompt and image for unconditional ("marginal") text generation.
- del inputs["prompt"]
- del inputs["image"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_img2text_v0(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_joint_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5760, 0.6270, 0.6571, 0.4965, 0.4638, 0.5663, 0.5254, 0.5068, 0.5716])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.5758, 0.6269, 0.6570, 0.4967, 0.4639, 0.5664, 0.5257, 0.5067, 0.5715])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
-
- def test_unidiffuser_default_img2text_v1(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained("hf-internal-testing/unidiffuser-test-v1")
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = " no no no "
- assert text[0][:10] == expected_text_prefix
-
- def test_unidiffuser_text2img_multiple_images(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_unidiffuser_text2img_multiple_images_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete image for text-conditioned image generation
- del inputs["image"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- image = unidiffuser_pipe(**inputs).images
- assert image.shape == (2, 32, 32, 3)
-
- def test_unidiffuser_img2text_multiple_prompts_with_latents(self):
- device = "cpu" # ensure determinism for the device-dependent torch.Generator
- components = self.get_dummy_components()
- unidiffuser_pipe = UniDiffuserPipeline(**components)
- unidiffuser_pipe = unidiffuser_pipe.to(device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(device)
- # Delete text for image-conditioned text generation
- del inputs["prompt"]
- inputs["num_images_per_prompt"] = 2
- inputs["num_prompts_per_image"] = 3
- text = unidiffuser_pipe(**inputs).text
-
- assert len(text) == 3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=2e-4)
-
- @require_torch_accelerator
- def test_unidiffuser_default_joint_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'joint'
- unidiffuser_pipe.set_joint_mode()
- assert unidiffuser_pipe.mode == "joint"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5049, 0.5498, 0.5854, 0.3052, 0.4460, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @require_torch_accelerator
- def test_unidiffuser_default_text2img_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'text2img'
- unidiffuser_pipe.set_text_to_image_mode()
- assert unidiffuser_pipe.mode == "text2img"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["image"]
- inputs["data_type"] = 1
- sample = unidiffuser_pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 32, 32, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.5054, 0.5498, 0.5854, 0.3052, 0.4458, 0.6489, 0.5122, 0.4810, 0.6138])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-3
-
- @require_torch_accelerator
- def test_unidiffuser_default_img2text_v1_fp16(self):
- unidiffuser_pipe = UniDiffuserPipeline.from_pretrained(
- "hf-internal-testing/unidiffuser-test-v1", torch_dtype=torch.float16
- )
- unidiffuser_pipe = unidiffuser_pipe.to(torch_device)
- unidiffuser_pipe.set_progress_bar_config(disable=None)
-
- # Set mode to 'img2text'
- unidiffuser_pipe.set_image_to_text_mode()
- assert unidiffuser_pipe.mode == "img2text"
-
- inputs = self.get_dummy_inputs_with_latents(torch_device)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- inputs["data_type"] = 1
- text = unidiffuser_pipe(**inputs).text
-
- expected_text_prefix = '" This This'
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- @unittest.skip(
- "Test not supported because it has a bunch of direct configs at init and also, this pipeline isn't used that much now."
- )
- def test_encode_prompt_works_in_isolation():
- pass
-
-
-@nightly
-@require_torch_accelerator
-class UniDiffuserPipelineSlowTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 1e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1")
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
-
-@nightly
-@require_torch_accelerator
-class UniDiffuserPipelineNightlyTests(unittest.TestCase):
- def setUp(self):
- super().setUp()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def tearDown(self):
- super().tearDown()
- gc.collect()
- backend_empty_cache(torch_device)
-
- def get_inputs(self, device, seed=0, generate_latents=False):
- generator = torch.manual_seed(seed)
- image = load_image(
- "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/unidiffuser/unidiffuser_example_image.jpg"
- )
- inputs = {
- "prompt": "an elephant under the sea",
- "image": image,
- "generator": generator,
- "num_inference_steps": 3,
- "guidance_scale": 8.0,
- "output_type": "np",
- }
- if generate_latents:
- latents = self.get_fixed_latents(device, seed=seed)
- for latent_name, latent_tensor in latents.items():
- inputs[latent_name] = latent_tensor
- return inputs
-
- def get_fixed_latents(self, device, seed=0):
- if isinstance(device, str):
- device = torch.device(device)
- latent_device = torch.device("cpu")
- generator = torch.Generator(device=latent_device).manual_seed(seed)
- # Hardcode the shapes for now.
- prompt_latents = randn_tensor((1, 77, 768), generator=generator, device=device, dtype=torch.float32)
- vae_latents = randn_tensor((1, 4, 64, 64), generator=generator, device=device, dtype=torch.float32)
- clip_latents = randn_tensor((1, 1, 512), generator=generator, device=device, dtype=torch.float32)
-
- # Move latents onto desired device.
- prompt_latents = prompt_latents.to(device)
- vae_latents = vae_latents.to(device)
- clip_latents = clip_latents.to(device)
-
- latents = {
- "prompt_latents": prompt_latents,
- "vae_latents": vae_latents,
- "clip_latents": clip_latents,
- }
- return latents
-
- def test_unidiffuser_default_joint_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- # inputs = self.get_dummy_inputs(device)
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- # Delete prompt and image for joint inference.
- del inputs["prompt"]
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- text = sample.text
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_img_slice = np.array([0.2402, 0.2375, 0.2285, 0.2378, 0.2407, 0.2263, 0.2354, 0.2307, 0.2520])
- assert np.abs(image_slice.flatten() - expected_img_slice).max() < 2e-1
-
- expected_text_prefix = "a living room"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
-
- def test_unidiffuser_default_text2img_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["image"]
- sample = pipe(**inputs)
- image = sample.images
- assert image.shape == (1, 512, 512, 3)
-
- image_slice = image[0, -3:, -3:, -1]
- expected_slice = np.array([0.0242, 0.0103, 0.0022, 0.0129, 0.0000, 0.0090, 0.0376, 0.0508, 0.0005])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-1
-
- def test_unidiffuser_default_img2text_v1_fp16(self):
- pipe = UniDiffuserPipeline.from_pretrained("thu-ml/unidiffuser-v1", torch_dtype=torch.float16)
- pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
- pipe.enable_attention_slicing()
-
- inputs = self.get_inputs(device=torch_device, generate_latents=True)
- del inputs["prompt"]
- sample = pipe(**inputs)
- text = sample.text
-
- expected_text_prefix = "An astronaut"
- assert text[0][: len(expected_text_prefix)] == expected_text_prefix
diff --git a/tests/pipelines/wan/test_wan.py b/tests/pipelines/wan/test_wan.py
index 842b9d19b3..a7e4e27813 100644
--- a/tests/pipelines/wan/test_wan.py
+++ b/tests/pipelines/wan/test_wan.py
@@ -15,7 +15,6 @@
import gc
import unittest
-import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel
@@ -29,9 +28,7 @@ from diffusers.utils.testing_utils import (
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import (
- PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
@@ -88,12 +85,29 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
rope_max_seq_len=32,
)
+ torch.manual_seed(0)
+ transformer_2 = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=16,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ )
+
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
+ "transformer_2": transformer_2,
}
return components
@@ -127,11 +141,15 @@ class WanPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.452, 0.4485, 0.4534, 0.4524, 0.4529, 0.454, 0.453, 0.5127, 0.5326, 0.5204, 0.5253, 0.5439, 0.5424, 0.5133, 0.5078])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
diff --git a/tests/pipelines/wan/test_wan_image_to_video.py b/tests/pipelines/wan/test_wan_image_to_video.py
index 22dfef2eb0..c693f4fcb2 100644
--- a/tests/pipelines/wan/test_wan_image_to_video.py
+++ b/tests/pipelines/wan/test_wan_image_to_video.py
@@ -14,7 +14,6 @@
import unittest
-import numpy as np
import torch
from PIL import Image
from transformers import (
@@ -87,6 +86,23 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image_dim=4,
)
+ torch.manual_seed(0)
+ transformer_2 = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ )
+
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
@@ -110,6 +126,7 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
+ "transformer_2": transformer_2,
}
return components
@@ -147,11 +164,15 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (9, 3, 16, 16))
- expected_video = torch.randn(9, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4525, 0.4525, 0.4497, 0.4536, 0.452, 0.4529, 0.454, 0.4535, 0.5072, 0.5527, 0.5165, 0.5244, 0.5481, 0.5282, 0.5208, 0.5214])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
@@ -161,8 +182,32 @@ class WanImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference_batch_single_identical(self):
pass
+ @unittest.skip(
+ "TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
+ )
+ def test_save_load_optional_components(self):
+ pass
+
+
+class WanFLFToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WanImageToVideoPipeline
+ params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "height", "width"}
+ batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
+ image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+ test_xformers_attention = False
+ supports_dduf = False
-class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
def get_dummy_components(self):
torch.manual_seed(0)
vae = AutoencoderKLWan(
@@ -197,6 +242,24 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
pos_embed_seq_len=2 * (4 * 4 + 1),
)
+ torch.manual_seed(0)
+ transformer_2 = WanTransformer3DModel(
+ patch_size=(1, 2, 2),
+ num_attention_heads=2,
+ attention_head_dim=12,
+ in_channels=36,
+ out_channels=16,
+ text_dim=32,
+ freq_dim=256,
+ ffn_dim=32,
+ num_layers=2,
+ cross_attn_norm=True,
+ qk_norm="rms_norm_across_heads",
+ rope_max_seq_len=32,
+ image_dim=4,
+ pos_embed_seq_len=2 * (4 * 4 + 1),
+ )
+
torch.manual_seed(0)
image_encoder_config = CLIPVisionConfig(
hidden_size=4,
@@ -220,6 +283,7 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
"tokenizer": tokenizer,
"image_encoder": image_encoder,
"image_processor": image_processor,
+ "transformer_2": transformer_2,
}
return components
@@ -247,3 +311,38 @@ class WanFLFToVideoPipelineFastTests(WanImageToVideoPipelineFastTests):
"output_type": "pt",
}
return inputs
+
+ def test_inference(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+ pipe = self.pipeline_class(**components)
+ pipe.to(device)
+ pipe.set_progress_bar_config(disable=None)
+
+ inputs = self.get_dummy_inputs(device)
+ video = pipe(**inputs).frames
+ generated_video = video[0]
+ self.assertEqual(generated_video.shape, (9, 3, 16, 16))
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4531, 0.4527, 0.4498, 0.4542, 0.4526, 0.4527, 0.4534, 0.4534, 0.5061, 0.5185, 0.5283, 0.5181, 0.5309, 0.5365, 0.5113, 0.5244])
+ # fmt: on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
+
+ @unittest.skip("Test not supported")
+ def test_attention_slicing_forward_pass(self):
+ pass
+
+ @unittest.skip("TODO: revisit failing as it requires a very high threshold to pass")
+ def test_inference_batch_single_identical(self):
+ pass
+
+ @unittest.skip(
+ "TODO: refactor this test: one component can be optional for certain checkpoints but not for others"
+ )
+ def test_save_load_optional_components(self):
+ pass
diff --git a/tests/pipelines/wan/test_wan_video_to_video.py b/tests/pipelines/wan/test_wan_video_to_video.py
index 11c748424a..f4bb0960ac 100644
--- a/tests/pipelines/wan/test_wan_video_to_video.py
+++ b/tests/pipelines/wan/test_wan_video_to_video.py
@@ -14,7 +14,6 @@
import unittest
-import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
@@ -123,11 +122,15 @@ class WanVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(device)
video = pipe(**inputs).frames
generated_video = video[0]
-
self.assertEqual(generated_video.shape, (17, 3, 16, 16))
- expected_video = torch.randn(17, 3, 16, 16)
- max_diff = np.abs(generated_video - expected_video).max()
- self.assertLessEqual(max_diff, 1e10)
+
+ # fmt: off
+ expected_slice = torch.tensor([0.4522, 0.4534, 0.4532, 0.4553, 0.4526, 0.4538, 0.4533, 0.4547, 0.513, 0.5176, 0.5286, 0.4958, 0.4955, 0.5381, 0.5154, 0.5195])
+ # fmt:on
+
+ generated_slice = generated_video.flatten()
+ generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
+ self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):
diff --git a/tests/pipelines/wuerstchen/__init__.py b/tests/pipelines/wuerstchen/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
deleted file mode 100644
index 060a11434e..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
+++ /dev/null
@@ -1,241 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenCombinedPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "generator",
- "height",
- "width",
- "latents",
- "prior_guidance_scale",
- "decoder_guidance_scale",
- "negative_prompt",
- "num_inference_steps",
- "return_dict",
- "prior_num_inference_steps",
- "output_type",
- ]
- test_xformers_attention = True
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2}
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_prior_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- prior_text_encoder = self.dummy_prior_text_encoder
-
- scheduler = DDPMWuerstchenScheduler()
- tokenizer = self.dummy_tokenizer
-
- text_encoder = self.dummy_text_encoder
- decoder = self.dummy_decoder
- vqgan = self.dummy_vqgan
-
- components = {
- "tokenizer": tokenizer,
- "text_encoder": text_encoder,
- "decoder": decoder,
- "vqgan": vqgan,
- "scheduler": scheduler,
- "prior_prior": prior,
- "prior_text_encoder": prior_text_encoder,
- "prior_tokenizer": tokenizer,
- "prior_scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "prior_guidance_scale": 4.0,
- "decoder_guidance_scale": 4.0,
- "num_inference_steps": 2,
- "prior_num_inference_steps": 2,
- "output_type": "np",
- "height": 128,
- "width": 128,
- }
- return inputs
-
- def test_wuerstchen(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
-
- assert image.shape == (1, 128, 128, 3)
-
- expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
-
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
- )
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2, (
- f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
- )
-
- @require_torch_accelerator
- def test_offloads(self):
- pipes = []
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components).to(torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_sequential_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- components = self.get_dummy_components()
- sd_pipe = self.pipeline_class(**components)
- sd_pipe.enable_model_cpu_offload(device=torch_device)
- pipes.append(sd_pipe)
-
- image_slices = []
- for pipe in pipes:
- inputs = self.get_dummy_inputs(torch_device)
- image = pipe(**inputs).images
-
- image_slices.append(image[0, -3:, -3:, -1].flatten())
-
- assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
- assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
-
- def test_inference_batch_single_identical(self):
- super().test_inference_batch_single_identical(expected_max_diff=1e-2)
-
- @unittest.skip(reason="flakey and float16 requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_inputs(self):
- pass
-
- @unittest.skip(reason="Test not supported.")
- def test_callback_cfg(self):
- pass
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
deleted file mode 100644
index 5d2462d48d..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
-from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
-from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenDecoderPipeline
- params = ["prompt"]
- batch_params = ["image_embeddings", "prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- projection_dim=self.text_embedder_hidden_size,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_vqgan(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "bottleneck_blocks": 1,
- "num_vq_embeddings": 2,
- }
- model = PaellaVQModel(**model_kwargs)
- return model.eval()
-
- @property
- def dummy_decoder(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_cond": self.text_embedder_hidden_size,
- "c_hidden": [320],
- "nhead": [-1],
- "blocks": [4],
- "level_config": ["CT"],
- "clip_embd": self.text_embedder_hidden_size,
- "inject_effnet": [False],
- }
-
- model = WuerstchenDiffNeXt(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- decoder = self.dummy_decoder
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
- vqgan = self.dummy_vqgan
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "decoder": decoder,
- "vqgan": vqgan,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- "latent_dim_scale": 4.0,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "image_embeddings": torch.ones((1, 4, 4, 4), device=device),
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 1.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_decoder(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.images
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
-
- image_slice = image[0, -3:, -3:, -1]
- image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
- assert image.shape == (1, 64, 64, 3)
-
- expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000])
- assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(expected_max_diff=1e-5)
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="bf16 not supported and requires CUDA")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- @unittest.skip("Test not supported.")
- def test_encode_prompt_works_in_isolation(self):
- super().test_encode_prompt_works_in_isolation()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
deleted file mode 100644
index 34f7c684b7..0000000000
--- a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
+++ /dev/null
@@ -1,273 +0,0 @@
-# coding=utf-8
-# Copyright 2025 HuggingFace Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import torch
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
-
-from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.pipelines.wuerstchen import WuerstchenPrior
-from diffusers.utils.import_utils import is_peft_available
-from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
-
-
-if is_peft_available():
- from peft import LoraConfig
- from peft.tuners.tuners_utils import BaseTunerLayer
-
-from ..test_pipelines_common import PipelineTesterMixin
-
-
-enable_full_determinism()
-
-
-class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
- pipeline_class = WuerstchenPriorPipeline
- params = ["prompt"]
- batch_params = ["prompt", "negative_prompt"]
- required_optional_params = [
- "num_images_per_prompt",
- "generator",
- "num_inference_steps",
- "latents",
- "negative_prompt",
- "guidance_scale",
- "output_type",
- "return_dict",
- ]
- test_xformers_attention = False
- callback_cfg_params = ["text_encoder_hidden_states"]
-
- @property
- def text_embedder_hidden_size(self):
- return 32
-
- @property
- def time_input_dim(self):
- return 32
-
- @property
- def block_out_channels_0(self):
- return self.time_input_dim
-
- @property
- def time_embed_dim(self):
- return self.time_input_dim * 4
-
- @property
- def dummy_tokenizer(self):
- tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
- return tokenizer
-
- @property
- def dummy_text_encoder(self):
- torch.manual_seed(0)
- config = CLIPTextConfig(
- bos_token_id=0,
- eos_token_id=2,
- hidden_size=self.text_embedder_hidden_size,
- intermediate_size=37,
- layer_norm_eps=1e-05,
- num_attention_heads=4,
- num_hidden_layers=5,
- pad_token_id=1,
- vocab_size=1000,
- )
- return CLIPTextModel(config).eval()
-
- @property
- def dummy_prior(self):
- torch.manual_seed(0)
-
- model_kwargs = {
- "c_in": 2,
- "c": 8,
- "depth": 2,
- "c_cond": 32,
- "c_r": 8,
- "nhead": 2,
- }
-
- model = WuerstchenPrior(**model_kwargs)
- return model.eval()
-
- def get_dummy_components(self):
- prior = self.dummy_prior
- text_encoder = self.dummy_text_encoder
- tokenizer = self.dummy_tokenizer
-
- scheduler = DDPMWuerstchenScheduler()
-
- components = {
- "prior": prior,
- "text_encoder": text_encoder,
- "tokenizer": tokenizer,
- "scheduler": scheduler,
- }
-
- return components
-
- def get_dummy_inputs(self, device, seed=0):
- if str(device).startswith("mps"):
- generator = torch.manual_seed(seed)
- else:
- generator = torch.Generator(device=device).manual_seed(seed)
- inputs = {
- "prompt": "horse",
- "generator": generator,
- "guidance_scale": 4.0,
- "num_inference_steps": 2,
- "output_type": "np",
- }
- return inputs
-
- def test_wuerstchen_prior(self):
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output = pipe(**self.get_dummy_inputs(device))
- image = output.image_embeddings
-
- image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
-
- image_slice = image[0, 0, 0, -10:]
- image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
- assert image.shape == (1, 2, 24, 24)
-
- expected_slice = np.array(
- [
- -7172.837,
- -3438.855,
- -1093.312,
- 388.8835,
- -7471.467,
- -7998.1206,
- -5328.259,
- 218.00089,
- -2731.5745,
- -8056.734,
- ]
- )
- assert np.abs(image_slice.flatten() - expected_slice).max() < 5e-2
- assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 5e-2
-
- @skip_mps
- def test_inference_batch_single_identical(self):
- self._test_inference_batch_single_identical(
- expected_max_diff=3e-1,
- )
-
- @skip_mps
- def test_attention_slicing_forward_pass(self):
- test_max_difference = torch_device == "cpu"
- test_mean_pixel_difference = False
-
- self._test_attention_slicing_forward_pass(
- test_max_difference=test_max_difference,
- test_mean_pixel_difference=test_mean_pixel_difference,
- )
-
- @unittest.skip(reason="flaky for now")
- def test_float16_inference(self):
- super().test_float16_inference()
-
- # override because we need to make sure latent_mean and latent_std to be 0
- def test_callback_inputs(self):
- components = self.get_dummy_components()
- components["latent_mean"] = 0
- components["latent_std"] = 0
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(torch_device)
- pipe.set_progress_bar_config(disable=None)
-
- self.assertTrue(
- hasattr(pipe, "_callback_tensor_inputs"),
- f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
- )
-
- def callback_inputs_test(pipe, i, t, callback_kwargs):
- missing_callback_inputs = set()
- for v in pipe._callback_tensor_inputs:
- if v not in callback_kwargs:
- missing_callback_inputs.add(v)
- self.assertTrue(
- len(missing_callback_inputs) == 0, f"Missing callback tensor inputs: {missing_callback_inputs}"
- )
- last_i = pipe.num_timesteps - 1
- if i == last_i:
- callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
- return callback_kwargs
-
- inputs = self.get_dummy_inputs(torch_device)
- inputs["callback_on_step_end"] = callback_inputs_test
- inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
- inputs["output_type"] = "latent"
-
- output = pipe(**inputs)[0]
- assert output.abs().sum() == 0
-
- def check_if_lora_correctly_set(self, model) -> bool:
- """
- Checks if the LoRA layers are correctly set with peft
- """
- for module in model.modules():
- if isinstance(module, BaseTunerLayer):
- return True
- return False
-
- def get_lora_components(self):
- prior = self.dummy_prior
-
- prior_lora_config = LoraConfig(
- r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
- )
-
- return prior, prior_lora_config
-
- @require_peft_backend
- def test_inference_with_prior_lora(self):
- _, prior_lora_config = self.get_lora_components()
- device = "cpu"
-
- components = self.get_dummy_components()
-
- pipe = self.pipeline_class(**components)
- pipe = pipe.to(device)
-
- pipe.set_progress_bar_config(disable=None)
-
- output_no_lora = pipe(**self.get_dummy_inputs(device))
- image_embed = output_no_lora.image_embeddings
- self.assertTrue(image_embed.shape == (1, 2, 24, 24))
-
- pipe.prior.add_adapter(prior_lora_config)
- self.assertTrue(self.check_if_lora_correctly_set(pipe.prior), "Lora not correctly set in prior")
-
- output_lora = pipe(**self.get_dummy_inputs(device))
- lora_image_embed = output_lora.image_embeddings
-
- self.assertTrue(image_embed.shape == lora_image_embed.shape)
-
- @unittest.skip("Test not supported as dtype cannot be inferred without the text encoder otherwise.")
- def test_encode_prompt_works_in_isolation(self):
- pass
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index bdb8920a39..8e2a8515c6 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -98,7 +98,14 @@ class Base4bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -865,23 +872,23 @@ class ExtendedSerializationTest(BaseBnb4BitSerializationTests):
@require_torch_version_greater("2.7.1")
-class Bnb4BitCompileTests(QuantCompileTests):
- quantization_config = PipelineQuantizationConfig(
- quant_backend="bitsandbytes_8bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer", "text_encoder_2"],
- )
+@require_bitsandbytes_version_greater("0.45.5")
+class Bnb4BitCompileTests(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16,
+ },
+ components_to_quantize=["transformer", "text_encoder_2"],
+ )
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
- super()._test_torch_compile(quantization_config=self.quantization_config)
+ super().test_torch_compile()
- def test_torch_compile_with_cpu_offload(self):
- super()._test_torch_compile_with_cpu_offload(quantization_config=self.quantization_config)
-
- def test_torch_compile_with_group_offload(self):
- super()._test_torch_compile_with_group_offload(quantization_config=self.quantization_config)
+ def test_torch_compile_with_group_offload_leaf(self):
+ super()._test_torch_compile_with_group_offload_leaf(use_stream=True)
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index d048b0b7db..64f56b02b0 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -99,7 +99,14 @@ class Base8bitTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
- torch.use_deterministic_algorithms(True)
+ cls.is_deterministic_enabled = torch.are_deterministic_algorithms_enabled()
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if not cls.is_deterministic_enabled:
+ torch.use_deterministic_algorithms(False)
def get_dummy_inputs(self):
prompt_embeds = load_pt(
@@ -830,24 +837,23 @@ class BaseBnb8bitSerializationTests(Base8bitTests):
@require_torch_version_greater_equal("2.6.0")
-class Bnb8BitCompileTests(QuantCompileTests):
- quantization_config = PipelineQuantizationConfig(
- quant_backend="bitsandbytes_8bit",
- quant_kwargs={"load_in_8bit": True},
- components_to_quantize=["transformer", "text_encoder_2"],
- )
+@require_bitsandbytes_version_greater("0.45.5")
+class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=["transformer", "text_encoder_2"],
+ )
def test_torch_compile(self):
torch._dynamo.config.capture_dynamic_output_shape_ops = True
- super()._test_torch_compile(quantization_config=self.quantization_config, torch_dtype=torch.float16)
+ super()._test_torch_compile(torch_dtype=torch.float16)
def test_torch_compile_with_cpu_offload(self):
- super()._test_torch_compile_with_cpu_offload(
- quantization_config=self.quantization_config, torch_dtype=torch.float16
- )
+ super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)
@pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
- def test_torch_compile_with_group_offload(self):
- super()._test_torch_compile_with_group_offload(
- quantization_config=self.quantization_config, torch_dtype=torch.float16
- )
+ def test_torch_compile_with_group_offload_leaf(self):
+ super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
diff --git a/tests/pipelines/blipdiffusion/__init__.py b/tests/quantization/gguf/__init__.py
similarity index 100%
rename from tests/pipelines/blipdiffusion/__init__.py
rename to tests/quantization/gguf/__init__.py
diff --git a/tests/quantization/gguf/test_gguf.py b/tests/quantization/gguf/test_gguf.py
index ae3900459d..ba41678eaa 100644
--- a/tests/quantization/gguf/test_gguf.py
+++ b/tests/quantization/gguf/test_gguf.py
@@ -8,6 +8,7 @@ import torch.nn as nn
from diffusers import (
AuraFlowPipeline,
AuraFlowTransformer2DModel,
+ DiffusionPipeline,
FluxControlPipeline,
FluxPipeline,
FluxTransformer2DModel,
@@ -15,6 +16,8 @@ from diffusers import (
HiDreamImageTransformer2DModel,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
+ WanTransformer3DModel,
+ WanVACETransformer3DModel,
)
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
@@ -30,9 +33,12 @@ from diffusers.utils.testing_utils import (
require_big_accelerator,
require_gguf_version_greater_or_equal,
require_peft_backend,
+ require_torch_version_greater,
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
+
if is_gguf_available():
from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
@@ -286,33 +292,33 @@ class SD35LargeGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase)
{
("xpu", 3): np.array(
[
- 0.19335938,
- 0.3125,
- 0.3203125,
- 0.1328125,
- 0.3046875,
- 0.296875,
- 0.11914062,
- 0.2890625,
- 0.2890625,
- 0.16796875,
- 0.30273438,
- 0.33203125,
- 0.14648438,
- 0.31640625,
- 0.33007812,
+ 0.16210938,
+ 0.2734375,
+ 0.27734375,
+ 0.109375,
+ 0.27148438,
+ 0.2578125,
+ 0.1015625,
+ 0.2578125,
+ 0.2578125,
+ 0.14453125,
+ 0.26953125,
+ 0.29492188,
0.12890625,
- 0.3046875,
- 0.30859375,
- 0.17773438,
- 0.33789062,
- 0.33203125,
- 0.16796875,
- 0.34570312,
- 0.32421875,
+ 0.28710938,
+ 0.30078125,
+ 0.11132812,
+ 0.27734375,
+ 0.27929688,
0.15625,
- 0.33203125,
- 0.31445312,
+ 0.31054688,
+ 0.296875,
+ 0.15234375,
+ 0.3203125,
+ 0.29492188,
+ 0.140625,
+ 0.3046875,
+ 0.28515625,
]
),
("cuda", 7): np.array(
@@ -577,3 +583,90 @@ class HiDreamGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
).to(torch_device, self.torch_dtype),
"timesteps": torch.tensor([1]).to(torch_device, self.torch_dtype),
}
+
+
+class WanGGUFTexttoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-T2V-14B-gguf/blob/main/wan2.1-t2v-14b-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanGGUFImagetoVideoSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/city96/Wan2.1-I2V-14B-480P-gguf/blob/main/wan2.1-i2v-14b-480p-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanTransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 36, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "encoder_hidden_states_image": torch.randn(
+ (1, 257, 1280), generator=torch.Generator("cpu").manual_seed(0)
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+class WanVACEGGUFSingleFileTests(GGUFSingleFileTesterMixin, unittest.TestCase):
+ ckpt_path = "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
+ torch_dtype = torch.bfloat16
+ model_cls = WanVACETransformer3DModel
+ expected_memory_use_in_gb = 9
+
+ def get_dummy_inputs(self):
+ return {
+ "hidden_states": torch.randn((1, 16, 2, 64, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
+ torch_device, self.torch_dtype
+ ),
+ "encoder_hidden_states": torch.randn(
+ (1, 512, 4096),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states": torch.randn(
+ (1, 96, 2, 64, 64),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "control_hidden_states_scale": torch.randn(
+ (8,),
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).to(torch_device, self.torch_dtype),
+ "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
+ }
+
+
+@require_torch_version_greater("2.7.1")
+class GGUFCompileTests(QuantCompileTests, unittest.TestCase):
+ torch_dtype = torch.bfloat16
+ gguf_ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
+
+ @property
+ def quantization_config(self):
+ return GGUFQuantizationConfig(compute_dtype=self.torch_dtype)
+
+ def _init_pipeline(self, *args, **kwargs):
+ transformer = FluxTransformer2DModel.from_single_file(
+ self.gguf_ckpt, quantization_config=self.quantization_config, torch_dtype=self.torch_dtype
+ )
+ pipe = DiffusionPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=self.torch_dtype
+ )
+ return pipe
diff --git a/tests/quantization/test_pipeline_level_quantization.py b/tests/quantization/test_pipeline_level_quantization.py
index 5a724df5c3..e91fe6d4cb 100644
--- a/tests/quantization/test_pipeline_level_quantization.py
+++ b/tests/quantization/test_pipeline_level_quantization.py
@@ -12,13 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import tempfile
import unittest
import torch
from parameterized import parameterized
-from diffusers import DiffusionPipeline, QuantoConfig
+from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils import logging
from diffusers.utils.testing_utils import (
@@ -243,3 +244,57 @@ class PipelineQuantizationTests(unittest.TestCase):
for name, component in pipe.components.items():
if isinstance(component, torch.nn.Module):
self.assertTrue(not hasattr(component.config, "quantization_config"))
+
+ @parameterized.expand(["quant_kwargs", "quant_mapping"])
+ def test_quant_config_repr(self, method):
+ component_name = "transformer"
+ if method == "quant_kwargs":
+ components_to_quantize = [component_name]
+ quant_config = PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_8bit",
+ quant_kwargs={"load_in_8bit": True},
+ components_to_quantize=components_to_quantize,
+ )
+ else:
+ quant_config = PipelineQuantizationConfig(
+ quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
+ )
+
+ pipe = DiffusionPipeline.from_pretrained(
+ self.model_name,
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ )
+ self.assertTrue(getattr(pipe, "quantization_config", None) is not None)
+ retrieved_config = pipe.quantization_config
+ expected_config = """
+transformer BitsAndBytesConfig {
+ "_load_in_4bit": false,
+ "_load_in_8bit": true,
+ "bnb_4bit_compute_dtype": "float32",
+ "bnb_4bit_quant_storage": "uint8",
+ "bnb_4bit_quant_type": "fp4",
+ "bnb_4bit_use_double_quant": false,
+ "llm_int8_enable_fp32_cpu_offload": false,
+ "llm_int8_has_fp16_weight": false,
+ "llm_int8_skip_modules": null,
+ "llm_int8_threshold": 6.0,
+ "load_in_4bit": false,
+ "load_in_8bit": true,
+ "quant_method": "bitsandbytes"
+}
+
+"""
+ expected_data = self._parse_config_string(expected_config)
+ actual_data = self._parse_config_string(str(retrieved_config))
+ self.assertTrue(actual_data == expected_data)
+
+ def _parse_config_string(self, config_string: str) -> tuple[str, dict]:
+ first_brace = config_string.find("{")
+ if first_brace == -1:
+ raise ValueError("Could not find opening brace '{' in the string.")
+
+ json_part = config_string[first_brace:]
+ data = json.loads(json_part)
+
+ return data
diff --git a/tests/quantization/test_torch_compile_utils.py b/tests/quantization/test_torch_compile_utils.py
index ba870ba733..c742927646 100644
--- a/tests/quantization/test_torch_compile_utils.py
+++ b/tests/quantization/test_torch_compile_utils.py
@@ -13,18 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
-import unittest
+import inspect
import torch
from diffusers import DiffusionPipeline
-from diffusers.utils.testing_utils import backend_empty_cache, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, require_torch_accelerator, slow, torch_device
-@require_torch_gpu
+@require_torch_accelerator
@slow
-class QuantCompileTests(unittest.TestCase):
- quantization_config = None
+class QuantCompileTests:
+ @property
+ def quantization_config(self):
+ raise NotImplementedError(
+ "This property should be implemented in the subclass to return the appropriate quantization config."
+ )
def setUp(self):
super().setUp()
@@ -46,42 +50,50 @@ class QuantCompileTests(unittest.TestCase):
)
return pipe
- def _test_torch_compile(self, quantization_config, torch_dtype=torch.bfloat16):
- pipe = self._init_pipeline(quantization_config, torch_dtype).to("cuda")
- # import to ensure fullgraph True
+ def _test_torch_compile(self, torch_dtype=torch.bfloat16):
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype).to(torch_device)
+ # `fullgraph=True` ensures no graph breaks
pipe.transformer.compile(fullgraph=True)
- for _ in range(2):
- # small resolutions to ensure speedy execution.
- pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+ # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
- def _test_torch_compile_with_cpu_offload(self, quantization_config, torch_dtype=torch.bfloat16):
- pipe = self._init_pipeline(quantization_config, torch_dtype)
+ def _test_torch_compile_with_cpu_offload(self, torch_dtype=torch.bfloat16):
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype)
pipe.enable_model_cpu_offload()
pipe.transformer.compile()
- for _ in range(2):
- # small resolutions to ensure speedy execution.
- pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+ # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
- def _test_torch_compile_with_group_offload(self, quantization_config, torch_dtype=torch.bfloat16):
- torch._dynamo.config.cache_size_limit = 10000
+ def _test_torch_compile_with_group_offload_leaf(self, torch_dtype=torch.bfloat16, *, use_stream: bool = False):
+ torch._dynamo.config.cache_size_limit = 1000
- pipe = self._init_pipeline(quantization_config, torch_dtype)
+ pipe = self._init_pipeline(self.quantization_config, torch_dtype)
group_offload_kwargs = {
- "onload_device": torch.device("cuda"),
+ "onload_device": torch.device(torch_device),
"offload_device": torch.device("cpu"),
"offload_type": "leaf_level",
- "use_stream": True,
- "non_blocking": True,
+ "use_stream": use_stream,
}
pipe.transformer.enable_group_offload(**group_offload_kwargs)
pipe.transformer.compile()
for name, component in pipe.components.items():
if name != "transformer" and isinstance(component, torch.nn.Module):
if torch.device(component.device).type == "cpu":
- component.to("cuda")
+ component.to(torch_device)
- for _ in range(2):
- # small resolutions to ensure speedy execution.
- pipe("a dog", num_inference_steps=3, max_sequence_length=16, height=256, width=256)
+ # small resolutions to ensure speedy execution.
+ pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
+
+ def test_torch_compile(self):
+ self._test_torch_compile()
+
+ def test_torch_compile_with_cpu_offload(self):
+ self._test_torch_compile_with_cpu_offload()
+
+ def test_torch_compile_with_group_offload_leaf(self, use_stream=False):
+ for cls in inspect.getmro(self.__class__):
+ if "test_torch_compile_with_group_offload_leaf" in cls.__dict__ and cls is not QuantCompileTests:
+ return
+ self._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
diff --git a/tests/quantization/torchao/test_torchao.py b/tests/quantization/torchao/test_torchao.py
index 0741c7f87c..5dcc207e65 100644
--- a/tests/quantization/torchao/test_torchao.py
+++ b/tests/quantization/torchao/test_torchao.py
@@ -19,6 +19,7 @@ import unittest
from typing import List
import numpy as np
+from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
@@ -29,6 +30,7 @@ from diffusers import (
TorchAoConfig,
)
from diffusers.models.attention_processor import Attention
+from diffusers.quantizers import PipelineQuantizationConfig
from diffusers.utils.testing_utils import (
backend_empty_cache,
backend_synchronize,
@@ -44,6 +46,8 @@ from diffusers.utils.testing_utils import (
torch_device,
)
+from ..test_torch_compile_utils import QuantCompileTests
+
enable_full_determinism()
@@ -232,7 +236,7 @@ class TorchAoTest(unittest.TestCase):
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])),
]
- if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])),
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])),
@@ -625,6 +629,50 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
+@require_torchao_version_greater_or_equal("0.7.0")
+class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
+ @property
+ def quantization_config(self):
+ return PipelineQuantizationConfig(
+ quant_mapping={
+ "transformer": TorchAoConfig(quant_type="int8_weight_only"),
+ },
+ )
+
+ @unittest.skip(
+ "Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
+ "when compiling."
+ )
+ def test_torch_compile_with_cpu_offload(self):
+ # RuntimeError: _apply(): Couldn't swap Linear.weight
+ super().test_torch_compile_with_cpu_offload()
+
+ @parameterized.expand([False, True])
+ @unittest.skip(
+ """
+ For `use_stream=False`:
+ - Changing the device of AQT tensor, with `param.data = param.data.to(device)` as done in group offloading implementation
+ is unsupported in TorchAO. When compiling, FakeTensor device mismatch causes failure.
+ For `use_stream=True`:
+ Using non-default stream requires ability to pin tensors. AQT does not seem to support this yet in TorchAO.
+ """
+ )
+ def test_torch_compile_with_group_offload_leaf(self, use_stream):
+ # For use_stream=False:
+ # If we run group offloading without compilation, we will see:
+ # RuntimeError: Attempted to set the storage of a tensor on device "cpu" to a storage on different device "cuda:0". This is no longer allowed; the devices must match.
+ # When running with compilation, the error ends up being different:
+ # Dynamo failed to run FX node with fake tensors: call_function (*(FakeTensor(..., device='cuda:0', size=(s0, 256), dtype=torch.bfloat16), AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=FakeTensor(..., size=(1536, 256), dtype=torch.int8)... , scale=FakeTensor(..., size=(1536,), dtype=torch.bfloat16)... , zero_point=FakeTensor(..., size=(1536,), dtype=torch.int64)... , _layout=PlainLayout()), block_size=(1, 256), shape=torch.Size([1536, 256]), device=cpu, dtype=torch.bfloat16, requires_grad=False), Parameter(FakeTensor(..., device='cuda:0', size=(1536,), dtype=torch.bfloat16,
+ # requires_grad=True))), **{}): got RuntimeError('Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu')
+ # Looks like something that will have to be looked into upstream.
+ # for linear layers, weight.tensor_impl shows cuda... but:
+ # weight.tensor_impl.{data,scale,zero_point}.device will be cpu
+
+ # For use_stream=True:
+ # NotImplementedError: AffineQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=, types=(,), arg_types=(,), kwarg_types={}
+ super()._test_torch_compile_with_group_offload_leaf(use_stream=use_stream)
+
+
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@@ -705,7 +753,7 @@ class SlowTorchAoTests(unittest.TestCase):
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])),
]
- if TorchAoConfig._is_cuda_capability_atleast_8_9():
+ if TorchAoConfig._is_xpu_or_cuda_capability_atleast_8_9():
QUANTIZATION_TYPES_TO_TEST.extend([
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])),
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])),
diff --git a/utils/print_env.py b/utils/print_env.py
index 2d2acb59d5..2fe0777daf 100644
--- a/utils/print_env.py
+++ b/utils/print_env.py
@@ -28,6 +28,16 @@ print("Python version:", sys.version)
print("OS platform:", platform.platform())
print("OS architecture:", platform.machine())
+try:
+ import psutil
+
+ vm = psutil.virtual_memory()
+ total_gb = vm.total / (1024**3)
+ available_gb = vm.available / (1024**3)
+ print(f"Total RAM: {total_gb:.2f} GB")
+ print(f"Available RAM: {available_gb:.2f} GB")
+except ImportError:
+ pass
try:
import torch